From d40982c1fc3919d7aef9b5be6cb28ebfd9b079db Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 14 Sep 2023 12:27:53 -0400 Subject: [PATCH 001/105] Add init --- src/aspire/abinitio/commonline_d2.py | 61 ++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 src/aspire/abinitio/commonline_d2.py diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py new file mode 100644 index 0000000000..9a72ef8ce4 --- /dev/null +++ b/src/aspire/abinitio/commonline_d2.py @@ -0,0 +1,61 @@ +import logging + +import numpy as np + +from aspire.abinitio import CLOrient3D + +logger = logging.getLogger(__name__) + + +class CLSymmetryD2(CLOrient3D): + """ + Define a class to estimate 3D orientations using common lines methods for + molecules with D2 (dihedral) symmetry. + + The related publications are: + E. Rosen and Y. Shkolnisky, + Common lines ab-initio reconstruction of D2-symmetric molecules, + """ + + def __init__( + self, + src, + n_rad=None, + n_theta=None, + max_shift=0.15, + shift_step=1, + grid_res=1200, + inplane_res=5, + eq_min_dist=7, + seed=None, + ): + """ + Initialize object for estimating 3D orientations for molecules with C3 and C4 symmetry. + + :param src: The source object of 2D denoised or class-averaged images with metadata + :param n_rad: The number of points in the radial direction + :param n_theta: The number of points in the theta direction + :param max_shift: Maximum range for shifts as a proportion of resolution. Default = 0.15. + :param shift_step: Resolution of shift estimation in pixels. Default = 1 pixel. + :param grid_res: Number of sampling points on sphere for projetion directions. + These are generated using the Saaf - Kuijlaars algoithm. Default value is 1200. + :param inplane_res: The sampling resolution of in-plane rotations for each + projetion direction. Default value is 5. + :param eq_min_dist: Width of strip around equator projection directions from + which we DO NOT sample directions. Default value is 7. + :param seed: Optional seed for RNG. + """ + + super().__init__( + src, + n_rad=n_rad, + n_theta=n_theta, + max_shift=max_shift, + shift_step=shift_step, + ) + + self.grid_res = grid_res + self.inplane_res = inplane_res + self.eq_min_dist = eq_min_dist + self.seed = seed + From e5c7ddd7509375e88d42dbc4f4c8a4a54a5ceea3 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 15 Sep 2023 10:53:13 -0400 Subject: [PATCH 002/105] saff_kuijlaars shpere points. --- src/aspire/abinitio/commonline_d2.py | 66 +++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 11 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 9a72ef8ce4..48cd717c22 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -15,19 +15,20 @@ class CLSymmetryD2(CLOrient3D): The related publications are: E. Rosen and Y. Shkolnisky, Common lines ab-initio reconstruction of D2-symmetric molecules, + SIAM Journal on Imaging Sciences, volume 13-4, p. 1898-1994, 2020 """ def __init__( - self, - src, - n_rad=None, - n_theta=None, - max_shift=0.15, - shift_step=1, - grid_res=1200, - inplane_res=5, - eq_min_dist=7, - seed=None, + self, + src, + n_rad=None, + n_theta=None, + max_shift=0.15, + shift_step=1, + grid_res=1200, + inplane_res=5, + eq_min_dist=7, + seed=None, ): """ Initialize object for estimating 3D orientations for molecules with C3 and C4 symmetry. @@ -58,4 +59,47 @@ def __init__( self.inplane_res = inplane_res self.eq_min_dist = eq_min_dist self.seed = seed - + + def estimate_rotations(self): + """ + Estimate rotation matrices for molecules with C3 or C4 symmetry. + + :return: Array of rotation matrices, size n_imgx3x3. + """ + pass + + def generate_lookup_data(self): + """ + Generate candidate relative rotations and corresponding common line indices. + """ + pass + + def saff_kuijlaars(self, N): + """ + Generates N vertices on the unit sphere that are approximately evenly distributed. + + This implements the recommended algorithm in spherical coordinates + (theta, phi) according to "Distributing many points on a sphere" + by E.B. Saff and A.B.J. Kuijlaars, Mathematical Intelligencer 19.1 + (1997) 5--11. + + :param N: Number of vertices to generate. + + :return: Nx3 array of vertices in cartesian coordinates. + """ + k = np.arange(1, N + 1) + h = -1 + 2 * (k - 1) / (N - 1) + theta = np.arccos(h) + phi = np.zeros(N) + + for i in range(1, N - 1): + phi[i] = (phi[i - 1] + 3.6 / (np.sqrt(N * (1 - h[i] ** 2)))) % (2 * np.pi) + + # Spherical coordinates + x = np.sin(theta) * np.cos(phi) + y = np.sin(theta) * np.sin(phi) + z = np.cos(theta) + + mesh = np.column_stack((x, y, z)) + + return mesh From 7fa0997e17cd453155ad9d5bc548d9d19ca6e99b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 19 Sep 2023 12:01:19 -0400 Subject: [PATCH 003/105] saff_kuijlaars and mark_equators partial. --- src/aspire/abinitio/commonline_d2.py | 72 ++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 48cd717c22..5ccbbb266c 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -39,7 +39,7 @@ def __init__( :param max_shift: Maximum range for shifts as a proportion of resolution. Default = 0.15. :param shift_step: Resolution of shift estimation in pixels. Default = 1 pixel. :param grid_res: Number of sampling points on sphere for projetion directions. - These are generated using the Saaf - Kuijlaars algoithm. Default value is 1200. + These are generated using the Saaf - Kuijlaars algorithm. Default value is 1200. :param inplane_res: The sampling resolution of in-plane rotations for each projetion direction. Default value is 5. :param eq_min_dist: Width of strip around equator projection directions from @@ -72,9 +72,28 @@ def generate_lookup_data(self): """ Generate candidate relative rotations and corresponding common line indices. """ - pass - - def saff_kuijlaars(self, N): + # Generate uniform grid on sphere with Saff-Kuijlaars and take one quarter + # of sphere because of D2 symmetry redundancy. + sphere_grid = self.saff_kuijlaars(self.grid_res) + octant1_mask = np.all(sphere_grid > 0, axis=1) + octant2_mask = ( + (sphere_grid[:, 0] > 0) & (sphere_grid[:, 1] > 0) & (sphere_grid[:, 2] < 0) + ) + sphere_grid1 = sphere_grid[octant1_mask] + sphere_grid2 = sphere_grid[octant2_mask] + + # Mark Equator Directions. + # Common lines between projection directions which are perpendicular to + # symmetry axes (equator images) have common line degeneracies. Two images + # taken from directions on the same great circle which is perpendicular to + # some symmetry axis only have 2 common lines instead of 4, and must be + # treated separately. + # We detect such directions by taking a strip of radius + # eq_filter_angle about the 3 great circles perpendicular to the symmetry + # axes of D2 (i.e to X,Y and Z axes). + + @staticmethod + def saff_kuijlaars(N): """ Generates N vertices on the unit sphere that are approximately evenly distributed. @@ -103,3 +122,48 @@ def saff_kuijlaars(self, N): mesh = np.column_stack((x, y, z)) return mesh + + @staticmethod + def mark_equators(sphere_grid, eq_filter_angle): + """ + :param sphere_grid: Nx3 array of vertices in cartesian coordinates. + :param eq_filter_angle: Angular distance from equator to be marked as + an equator point. + + :return: Indices of points on sphere whose distance from one of + the equators is < eq_filter angle. + """ + # Project each vector onto xy, xz, yz planes and measure angular distance + # from each plane. + n_rots = len(sphere_grid) + angular_dists = np.zeros(3, n_rots, dtype=sphere_grid.dtype) + + proj_xy = sphere_grid.copy() + proj_xy[:, 2] = 0 + proj_xy /= np.linalg.norm(proj_xy, axis=1)[:, None] + angular_dists[0] = np.sum(sphere_grid * proj_xy, axis=-1) + + proj_xz = sphere_grid.copy() + proj_xz[:, 1] = 0 + proj_xz /= np.linalg.norm(proj_xz, axis=1)[:, None] + angular_dists[1] = np.sum(sphere_grid * proj_xz, axis=-1) + + proj_yz = sphere_grid.copy() + proj_yz[:, 0] = 0 + proj_yz /= np.linalg.norm(proj_yz, axis=1)[:, None] + angular_dists[2] = np.sum(sphere_grid * proj_yz, axis=-1) + + # Mark points close to equator (within eq_filter_angle). + eq_min_dist = np.cos(eq_filter_angle * np.pi / 180) + n_eqs_close = np.sum(angular_dists > eq_min_dist, axis=0) + eq_mask = n_eqs_close > 0 + + # Classify equators. + # 1 -> z equator + # 2 -> y equator + # 3 -> x equator + # 4 -> z top view, ie. both x and y equator + # 5 -> y top view, ie. both x and z equator + # 6 -> x top view, ie. both y and z equator + eq_class = np.zeros(n_rots) + top_view_mask = n_eqs_close > 1 From 2ea41eb4f52a27d47a1eb07d22bae37e9f305475 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 22 Sep 2023 15:03:38 -0400 Subject: [PATCH 004/105] Mark top views and equators. Generate inplane rotations. --- src/aspire/abinitio/commonline_d2.py | 110 ++++++++++++++++++++++----- 1 file changed, 91 insertions(+), 19 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 5ccbbb266c..cb0de5958c 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -3,6 +3,7 @@ import numpy as np from aspire.abinitio import CLOrient3D +from aspire.utils import Rotation logger = logging.getLogger(__name__) @@ -91,6 +92,30 @@ def generate_lookup_data(self): # We detect such directions by taking a strip of radius # eq_filter_angle about the 3 great circles perpendicular to the symmetry # axes of D2 (i.e to X,Y and Z axes). + eq_mask1, top_view_mask1 = self.mark_equators(sphere_grid1, self.eq_min_dist) + eq_mask2, top_view_mask2 = self.mark_equators(sphere_grid2, self.eq_min_dist) + + # Mark Top View Directions. + # A Top view projection image is taken from the direction of one of the + # symmetry axes. Since all symmetry axes of D2 molecules are perpendicular + # this means that such an image is an equator with repect to both symmetry + # axes which are perpendicular to the direction of the symmetry axis from + # which the image was made, e.g. if the image was formed by projecting in + # the direction of the X (symmetry) axis, then it is an equator with + # respect to both Y and Z symmetry axes (it's direction is the + # interesection of 2 great circles perpendicular to Y and Z axes). + # Such images have severe degeneracies. A pair of Top View images (taken + # from different directions or a Top View and equator image only have a + # single common line. A top view and a regular non-equator image only have + # two common lines. + + # Remove top views from sphere grids and update equator masks. + sphere_grid1 = sphere_grid1[~top_view_mask1] + sphere_grid2 = sphere_grid2[~top_view_mask2] + eq_mask1 = eq_mask1[~top_view_mask1] + eq_mask2 = eq_mask2[~top_view_mask2] + + # Generate in-plane rotations for each grid point on the sphere. @staticmethod def saff_kuijlaars(N): @@ -135,35 +160,82 @@ def mark_equators(sphere_grid, eq_filter_angle): """ # Project each vector onto xy, xz, yz planes and measure angular distance # from each plane. - n_rots = len(sphere_grid) - angular_dists = np.zeros(3, n_rots, dtype=sphere_grid.dtype) + eq_min_dist = np.cos(eq_filter_angle * np.pi / 180) + # Mask for z-axis equator views. proj_xy = sphere_grid.copy() proj_xy[:, 2] = 0 proj_xy /= np.linalg.norm(proj_xy, axis=1)[:, None] - angular_dists[0] = np.sum(sphere_grid * proj_xy, axis=-1) + ang_dists_xy = np.sum(sphere_grid * proj_xy, axis=-1) + z_eq_mask = ang_dists_xy > eq_min_dist + # Mask for y-axis equator views. proj_xz = sphere_grid.copy() proj_xz[:, 1] = 0 proj_xz /= np.linalg.norm(proj_xz, axis=1)[:, None] - angular_dists[1] = np.sum(sphere_grid * proj_xz, axis=-1) + ang_dists_xz = np.sum(sphere_grid * proj_xz, axis=-1) + y_eq_mask = ang_dists_xz > eq_min_dist + # Mask for x-axis equator views. proj_yz = sphere_grid.copy() proj_yz[:, 0] = 0 proj_yz /= np.linalg.norm(proj_yz, axis=1)[:, None] - angular_dists[2] = np.sum(sphere_grid * proj_yz, axis=-1) + ang_dists_yz = np.sum(sphere_grid * proj_yz, axis=-1) + x_eq_mask = ang_dists_yz > eq_min_dist - # Mark points close to equator (within eq_filter_angle). - eq_min_dist = np.cos(eq_filter_angle * np.pi / 180) - n_eqs_close = np.sum(angular_dists > eq_min_dist, axis=0) - eq_mask = n_eqs_close > 0 - - # Classify equators. - # 1 -> z equator - # 2 -> y equator - # 3 -> x equator - # 4 -> z top view, ie. both x and y equator - # 5 -> y top view, ie. both x and z equator - # 6 -> x top view, ie. both y and z equator - eq_class = np.zeros(n_rots) - top_view_mask = n_eqs_close > 1 + # Mask for all views close to an equator. + eq_mask = z_eq_mask | y_eq_mask | x_eq_mask + + # Top view masks. + # A top view is a view along an axis of symmetry (ie. x, y, or z). + # A top view is also at the intersection of the two equator views + # perpendicular to the axis of symmetry. + z_top_view_mask = y_eq_mask & x_eq_mask + y_top_view_mask = z_eq_mask & x_eq_mask + x_top_view_mask = z_eq_mask & y_eq_mask + top_view_mask = z_top_view_mask | y_top_view_mask | x_top_view_mask + + return eq_mask, top_view_mask + + @staticmethod + def generate_inplane_rots(sphere_grid, d_theta): + """ + This function takes projection directions (points on the 2-sphere) and + generates rotation matrices in SO(3). The projection direction + is the 3rd column and columns 1 and 2 span the perpendicular plane. + To properly discretize SO(3), for each projection direction we generate + [2*pi/dtheta] "in-plane" rotations, of the plane + perpendicular to this direction. This is done by generating one rotation + for each direction and then multiplying on the right by a rotation about + the Z-axis by k*dtheta degrees, k=0...2*pi/dtheta-1. + + :param sphere_grid: A set of points on the 2-sphere. + :param d_theta: Resolution for in-plane rotations (in degrees) + :returns: 4D array of rotations of size len(sphere_grid) x n_inplane_rots x 3 x 3. + """ + dtype = sphere_grid.dtype + # Generate one rotation for each point on the sphere. + n_rots = len(sphere_grid) + Ri2 = np.column_stack((-sphere_grid[:, 2], sphere_grid[:, 1], np.zeros(n_rots))) + Ri2 /= np.linalg.norm(Ri2, axis=1)[:, None] + Ri1 = np.cross(Ri2, sphere_grid) + Ri1 /= np.linalg.norm(Ri1, axis=1)[:, None] + + rots_grid = np.zeros((n_rots, 3, 3), dtype=dtype) + rots_grid[:, :, 0] = Ri1 + rots_grid[:, :, 1] = Ri2 + rots_grid[:, :, 2] = sphere_grid + + # Generate in-plane rotations. + d_theta *= np.pi / 180 + inplane_rots = Rotation.about_axis( + "z", np.arange(0, 2 * np.pi, d_theta), dtype=dtype + ).matrices + n_inplane_rots = len(inplane_rots) + + # Generate in-plane rotations of rots_grid. + inplane_rotated_grid = np.zeros((n_rots, n_inplane_rots, 3, 3), dtype=dtype) + for i in range(n_rots): + inplane_rotated_grid[i] = rots_grid[i] @ inplane_rots + + return inplane_rotated_grid From 81437b488a5780d6ee2cce8c3212dd191f8ee3b7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 1 Dec 2023 15:48:19 -0500 Subject: [PATCH 005/105] match matlab MarkEquators --- src/aspire/abinitio/__init__.py | 1 + src/aspire/abinitio/commonline_d2.py | 136 +++++++++++++++++++++++++-- 2 files changed, 128 insertions(+), 9 deletions(-) diff --git a/src/aspire/abinitio/__init__.py b/src/aspire/abinitio/__init__.py index 9d4b0f483c..e8115ea185 100644 --- a/src/aspire/abinitio/__init__.py +++ b/src/aspire/abinitio/__init__.py @@ -8,5 +8,6 @@ from .commonline_c3_c4 import CLSymmetryC3C4 from .commonline_cn import CLSymmetryCn from .commonline_c2 import CLSymmetryC2 +from .commonline_d2 import CLSymmetryD2 # isort: on diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index cb0de5958c..234adc1e2c 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -92,8 +92,8 @@ def generate_lookup_data(self): # We detect such directions by taking a strip of radius # eq_filter_angle about the 3 great circles perpendicular to the symmetry # axes of D2 (i.e to X,Y and Z axes). - eq_mask1, top_view_mask1 = self.mark_equators(sphere_grid1, self.eq_min_dist) - eq_mask2, top_view_mask2 = self.mark_equators(sphere_grid2, self.eq_min_dist) + eq_idx1, eq_class1 = self.mark_equators(sphere_grid1, self.eq_min_dist) + eq_idx2, eq_class2 = self.mark_equators(sphere_grid2, self.eq_min_dist) # Mark Top View Directions. # A Top view projection image is taken from the direction of one of the @@ -109,13 +109,31 @@ def generate_lookup_data(self): # single common line. A top view and a regular non-equator image only have # two common lines. - # Remove top views from sphere grids and update equator masks. - sphere_grid1 = sphere_grid1[~top_view_mask1] - sphere_grid2 = sphere_grid2[~top_view_mask2] - eq_mask1 = eq_mask1[~top_view_mask1] - eq_mask2 = eq_mask2[~top_view_mask2] + # Remove top views from sphere grids and update equator indices and classes. + sphere_grid1 = sphere_grid1[eq_class1 < 4] + sphere_grid2 = sphere_grid2[eq_class2 < 4] + eq_idx1 = eq_idx1[eq_class1 < 4] + eq_idx2 = eq_idx2[eq_class2 < 4] + eq_class1 = eq_class1[eq_class1 < 4] + eq_class2 = eq_class2[eq_class2 < 4] # Generate in-plane rotations for each grid point on the sphere. + inplane_rotated_grid1 = self.generate_inplane_rots( + sphere_grid1, self.inplane_res + ) + inplane_rotated_grid2 = self.generate_inplane_rots( + sphere_grid2, self.inplane_res + ) + + # Generate all relative rotation candidates for maximum-likelihood method. + rots = self.generate_relative_rotations( + inplane_rotated_grid1, + inplane_rotated_grid1, + eq_idx1, + eq_idx1, + eq_class1, + eq_class1, + ) @staticmethod def saff_kuijlaars(N): @@ -149,7 +167,7 @@ def saff_kuijlaars(N): return mesh @staticmethod - def mark_equators(sphere_grid, eq_filter_angle): + def mark_equators1(sphere_grid, eq_filter_angle): """ :param sphere_grid: Nx3 array of vertices in cartesian coordinates. :param eq_filter_angle: Angular distance from equator to be marked as @@ -195,7 +213,72 @@ def mark_equators(sphere_grid, eq_filter_angle): x_top_view_mask = z_eq_mask & y_eq_mask top_view_mask = z_top_view_mask | y_top_view_mask | x_top_view_mask - return eq_mask, top_view_mask + masks = { + "eq": eq_mask, + "top": top_view_mask, + "x_eq": x_eq_mask, + "y_eq": y_eq_mask, + "z_eq": z_eq_mask, + } + + return masks + + @staticmethod + def mark_equators(sphere_grid, eq_filter_angle): + """ + :param sphere_grid: Nx3 array of vertices in cartesian coordinates. + :param eq_filter_angle: Angular distance from equator to be marked as + an equator point. + + :returns: + - eq_idx, a boolean mask for equator indices. + - eq_class, n_rots length array of values indicating equator class. + """ + # Project each vector onto xy, xz, yz planes and measure angular distance + # from each plane. + n_rots = len(sphere_grid) + angular_dists = np.zeros((n_rots, 3), dtype=sphere_grid.dtype) + + # Distance from z-axis equator. + proj_xy = sphere_grid.copy() + proj_xy[:, 2] = 0 + proj_xy /= np.linalg.norm(proj_xy, axis=1)[:, None] + angular_dists[:, 0] = np.sum(sphere_grid * proj_xy, axis=-1) + + # Distance from y-axis equator. + proj_xz = sphere_grid.copy() + proj_xz[:, 1] = 0 + proj_xz /= np.linalg.norm(proj_xz, axis=1)[:, None] + angular_dists[:, 1] = np.sum(sphere_grid * proj_xz, axis=-1) + + # Distance from x-axis equator. + proj_yz = sphere_grid.copy() + proj_yz[:, 0] = 0 + proj_yz /= np.linalg.norm(proj_yz, axis=1)[:, None] + angular_dists[:, 2] = np.sum(sphere_grid * proj_yz, axis=-1) + + # Mark all views close to an equator. + eq_min_dist = np.cos(eq_filter_angle * np.pi / 180) + n_eqs = np.sum(angular_dists > eq_min_dist, axis=1) + eq_idx = n_eqs > 0 + + # Classify equators. + # 0 -> non-equator view + # 1 -> z equator + # 2 -> y equator + # 3 -> x equator + # 4 -> z top view + # 5 -> y top view + # 6 -> x top view + eq_class = np.zeros(n_rots) + top_view_idx = n_eqs > 1 + top_view_class = np.argmin(angular_dists[top_view_idx] > eq_min_dist) + eq_class[top_view_idx] = top_view_class + 4 + eq_view_idx = n_eqs == 1 + eq_view_class = np.argmax(angular_dists[eq_view_idx] > eq_min_dist, axis=1) + eq_class[eq_view_idx] = eq_view_class + 1 + + return eq_idx, eq_class @staticmethod def generate_inplane_rots(sphere_grid, d_theta): @@ -239,3 +322,38 @@ def generate_inplane_rots(sphere_grid, d_theta): inplane_rotated_grid[i] = rots_grid[i] @ inplane_rots return inplane_rotated_grid + + @staticmethod + def generate_relative_rotations( + Ris, Rjs, Ri_eq_idx, Rj_eq_idx, Ri_eq_class, Rj_eq_class + ): + """ + :param Ris: First set of candidate rotations. + :param Rjs: Second set of candidate rotation. + :param Ri_eq_idx: + """ + n_rots_i = len(Ris) + n_rots_j = len(Rjs) + n_theta = Ris.shape[1] # Same for Rjs + + # Generate upper triangular table of indicators of all pairs which are not + # equators with respect to the same symmetry axis (named unique_pairs). + eq_table = np.outer(Ri_eq_idx, Rj_eq_idx) + in_same_class = (Ri_eq_class[:, None] - Rj_eq_class.T[None]) == 0 + eq2eq_Rij_table = np.triu(~(eq_table * in_same_class)) + + n_pairs = np.sum(eq2eq_Rij_table) + idx = 0 + cls = np.zeros((2 * n_pairs, n_theta, n_theta // 2, 2, 4)) + + for i in range(n_rots_i): + unique_pairs_i = np.where(eq2eq_Rij_table[i])[0] + if len(unique_pairs_i) == 0: + continue + Ri = Ris[i] + for j in unique_pairs_i: + # Compute relative rotations candidates + Rj = Rjs[j, : n_theta // 2] + import pdb + + pdb.set_trace() From 875e2337684ba7976f2d020a6d1b5e4190e9309c Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 4 Dec 2023 13:59:21 -0500 Subject: [PATCH 006/105] relative rotations. --- src/aspire/abinitio/commonline_d2.py | 141 ++++++++++++++------------- 1 file changed, 75 insertions(+), 66 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 234adc1e2c..e5a95337ca 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -125,8 +125,8 @@ def generate_lookup_data(self): sphere_grid2, self.inplane_res ) - # Generate all relative rotation candidates for maximum-likelihood method. - rots = self.generate_relative_rotations( + # Generate commmon line angles induced by all relative rotation candidates. + cl_angles_1 = self.generate_relative_rotations( inplane_rotated_grid1, inplane_rotated_grid1, eq_idx1, @@ -134,6 +134,16 @@ def generate_lookup_data(self): eq_class1, eq_class1, ) + cl_angles_2 = self.generate_relative_rotations( + inplane_rotated_grid1, + inplane_rotated_grid2, + eq_idx1, + eq_idx2, + eq_class1, + eq_class2, + ) + + return cl_angles_1, cl_angles_2 @staticmethod def saff_kuijlaars(N): @@ -166,63 +176,6 @@ def saff_kuijlaars(N): return mesh - @staticmethod - def mark_equators1(sphere_grid, eq_filter_angle): - """ - :param sphere_grid: Nx3 array of vertices in cartesian coordinates. - :param eq_filter_angle: Angular distance from equator to be marked as - an equator point. - - :return: Indices of points on sphere whose distance from one of - the equators is < eq_filter angle. - """ - # Project each vector onto xy, xz, yz planes and measure angular distance - # from each plane. - eq_min_dist = np.cos(eq_filter_angle * np.pi / 180) - - # Mask for z-axis equator views. - proj_xy = sphere_grid.copy() - proj_xy[:, 2] = 0 - proj_xy /= np.linalg.norm(proj_xy, axis=1)[:, None] - ang_dists_xy = np.sum(sphere_grid * proj_xy, axis=-1) - z_eq_mask = ang_dists_xy > eq_min_dist - - # Mask for y-axis equator views. - proj_xz = sphere_grid.copy() - proj_xz[:, 1] = 0 - proj_xz /= np.linalg.norm(proj_xz, axis=1)[:, None] - ang_dists_xz = np.sum(sphere_grid * proj_xz, axis=-1) - y_eq_mask = ang_dists_xz > eq_min_dist - - # Mask for x-axis equator views. - proj_yz = sphere_grid.copy() - proj_yz[:, 0] = 0 - proj_yz /= np.linalg.norm(proj_yz, axis=1)[:, None] - ang_dists_yz = np.sum(sphere_grid * proj_yz, axis=-1) - x_eq_mask = ang_dists_yz > eq_min_dist - - # Mask for all views close to an equator. - eq_mask = z_eq_mask | y_eq_mask | x_eq_mask - - # Top view masks. - # A top view is a view along an axis of symmetry (ie. x, y, or z). - # A top view is also at the intersection of the two equator views - # perpendicular to the axis of symmetry. - z_top_view_mask = y_eq_mask & x_eq_mask - y_top_view_mask = z_eq_mask & x_eq_mask - x_top_view_mask = z_eq_mask & y_eq_mask - top_view_mask = z_top_view_mask | y_top_view_mask | x_top_view_mask - - masks = { - "eq": eq_mask, - "top": top_view_mask, - "x_eq": x_eq_mask, - "y_eq": y_eq_mask, - "z_eq": z_eq_mask, - } - - return masks - @staticmethod def mark_equators(sphere_grid, eq_filter_angle): """ @@ -333,7 +286,6 @@ def generate_relative_rotations( :param Ri_eq_idx: """ n_rots_i = len(Ris) - n_rots_j = len(Rjs) n_theta = Ris.shape[1] # Same for Rjs # Generate upper triangular table of indicators of all pairs which are not @@ -344,7 +296,7 @@ def generate_relative_rotations( n_pairs = np.sum(eq2eq_Rij_table) idx = 0 - cls = np.zeros((2 * n_pairs, n_theta, n_theta // 2, 2, 4)) + cls = np.zeros((2 * n_pairs, n_theta, n_theta // 2, 4, 2)) for i in range(n_rots_i): unique_pairs_i = np.where(eq2eq_Rij_table[i])[0] @@ -352,8 +304,65 @@ def generate_relative_rotations( continue Ri = Ris[i] for j in unique_pairs_i: - # Compute relative rotations candidates - Rj = Rjs[j, : n_theta // 2] - import pdb - - pdb.set_trace() + # Compute relative rotations candidates Rij = Ri.T @ Rj + Rj = Rjs[j, : (n_theta // 2)] + Rijs = np.transpose(Rj, axes=(0, 2, 1)) @ Ri[:, None] + + # Common line indices induced by Rijs + cls[idx, :, :, 0, 0] = np.arctan2(Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1]) + cls[idx, :, :, 0, 1] = np.arctan2(-Rijs[:, :, 0, 2], Rijs[:, :, 1, 2]) + cls[idx + n_pairs, :, :, 0, 0] = np.arctan2( + Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] + ) + cls[idx + n_pairs, :, :, 0, 1] = np.arctan2( + -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] + ) + + # Compute relative rotations candidates Rij = Ri.T @ g1 @ Rj, + # where g1 = diag(1, -1, -1). + g1_Rj = Rj.copy() + g1_Rj[:, 1:3] = -g1_Rj[:, 1:3] + Rijs = np.transpose(g1_Rj, axes=(0, 2, 1)) @ Ri[:, None] + + cls[idx, :, :, 1, 0] = np.arctan2(Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1]) + cls[idx, :, :, 1, 1] = np.arctan2(-Rijs[:, :, 0, 2], Rijs[:, :, 1, 2]) + cls[idx + n_pairs, :, :, 1, 0] = np.arctan2( + Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] + ) + cls[idx + n_pairs, :, :, 1, 1] = np.arctan2( + -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] + ) + + # Compute relative rotations candidates Rij = Ri.T @ g2 @ Rj, + # where g2 = diag(-1, 1, -1). + g2_Rj = Rj.copy() + g2_Rj[:, [0, 2]] = -g2_Rj[:, [0, 2]] + Rijs = np.transpose(g2_Rj, axes=(0, 2, 1)) @ Ri[:, None] + + cls[idx, :, :, 2, 0] = np.arctan2(Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1]) + cls[idx, :, :, 2, 1] = np.arctan2(-Rijs[:, :, 0, 2], Rijs[:, :, 1, 2]) + cls[idx + n_pairs, :, :, 2, 0] = np.arctan2( + Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] + ) + cls[idx + n_pairs, :, :, 2, 1] = np.arctan2( + -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] + ) + + # Compute relative rotations candidates Rij = Ri.T @ g3 @ Rj, + # where g3 = diag(-1, -1, 1). + g3_Rj = Rj.copy() + g3_Rj[:, 0:2] = -g3_Rj[:, 0:2] + Rijs = np.transpose(g3_Rj, axes=(0, 2, 1)) @ Ri[:, None] + + cls[idx, :, :, 3, 0] = np.arctan2(Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1]) + cls[idx, :, :, 3, 1] = np.arctan2(-Rijs[:, :, 0, 2], Rijs[:, :, 1, 2]) + cls[idx + n_pairs, :, :, 3, 0] = np.arctan2( + Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] + ) + cls[idx + n_pairs, :, :, 3, 1] = np.arctan2( + -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] + ) + + idx += 1 + + return cls From 64c915a07e928f30f4c5a0ab36ceeae104f1e3f1 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 6 Dec 2023 10:01:34 -0500 Subject: [PATCH 007/105] generate_commonline_indices --- src/aspire/abinitio/commonline_d2.py | 99 +++++++++++++++++++++------- 1 file changed, 75 insertions(+), 24 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index e5a95337ca..5eb6fafce1 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -63,7 +63,7 @@ def __init__( def estimate_rotations(self): """ - Estimate rotation matrices for molecules with C3 or C4 symmetry. + Estimate rotation matrices for molecules with D2 symmetry. :return: Array of rotation matrices, size n_imgx3x3. """ @@ -126,7 +126,7 @@ def generate_lookup_data(self): ) # Generate commmon line angles induced by all relative rotation candidates. - cl_angles_1 = self.generate_relative_rotations( + cl_angles1 = self.generate_commonline_angles( inplane_rotated_grid1, inplane_rotated_grid1, eq_idx1, @@ -134,7 +134,7 @@ def generate_lookup_data(self): eq_class1, eq_class1, ) - cl_angles_2 = self.generate_relative_rotations( + cl_angles2 = self.generate_commonline_angles( inplane_rotated_grid1, inplane_rotated_grid2, eq_idx1, @@ -143,7 +143,10 @@ def generate_lookup_data(self): eq_class2, ) - return cl_angles_1, cl_angles_2 + cl_ind_1 = self.generate_commonline_indices(cl_angles1) + cl_ind_2 = self.generate_commonline_indices(cl_angles2) + + return cl_angles1, cl_angles2 @staticmethod def saff_kuijlaars(N): @@ -277,13 +280,22 @@ def generate_inplane_rots(sphere_grid, d_theta): return inplane_rotated_grid @staticmethod - def generate_relative_rotations( + def generate_commonline_angles( Ris, Rjs, Ri_eq_idx, Rj_eq_idx, Ri_eq_class, Rj_eq_class ): """ + Compute commonline angles induced by the 4 sets of relative rotations + Rij = Ri.T @ g_m @ Rj, m = 0,1,2,3, where g_m is the identity and rotations + about the three axes of symmetry of a D2 symmetric molecule. + :param Ris: First set of candidate rotations. :param Rjs: Second set of candidate rotation. - :param Ri_eq_idx: + :param Ri_eq_idx: Equator index mask. + :param Rj_eq_idx: Equator index mask. + :param Ri_eq_class: Equator classification for Ris. + :param Rj_eq_class: Equator classification for Rjs. + + :return: Commonline angles induced by relative rotation candidates. """ n_rots_i = len(Ris) n_theta = Ris.shape[1] # Same for Rjs @@ -296,7 +308,7 @@ def generate_relative_rotations( n_pairs = np.sum(eq2eq_Rij_table) idx = 0 - cls = np.zeros((2 * n_pairs, n_theta, n_theta // 2, 4, 2)) + cl_angles = np.zeros((2 * n_pairs, n_theta, n_theta // 2, 4, 2)) for i in range(n_rots_i): unique_pairs_i = np.where(eq2eq_Rij_table[i])[0] @@ -309,12 +321,16 @@ def generate_relative_rotations( Rijs = np.transpose(Rj, axes=(0, 2, 1)) @ Ri[:, None] # Common line indices induced by Rijs - cls[idx, :, :, 0, 0] = np.arctan2(Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1]) - cls[idx, :, :, 0, 1] = np.arctan2(-Rijs[:, :, 0, 2], Rijs[:, :, 1, 2]) - cls[idx + n_pairs, :, :, 0, 0] = np.arctan2( + cl_angles[idx, :, :, 0, 0] = np.arctan2( + Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] + ) + cl_angles[idx, :, :, 0, 1] = np.arctan2( + -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] + ) + cl_angles[idx + n_pairs, :, :, 0, 0] = np.arctan2( Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] ) - cls[idx + n_pairs, :, :, 0, 1] = np.arctan2( + cl_angles[idx + n_pairs, :, :, 0, 1] = np.arctan2( -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] ) @@ -324,12 +340,16 @@ def generate_relative_rotations( g1_Rj[:, 1:3] = -g1_Rj[:, 1:3] Rijs = np.transpose(g1_Rj, axes=(0, 2, 1)) @ Ri[:, None] - cls[idx, :, :, 1, 0] = np.arctan2(Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1]) - cls[idx, :, :, 1, 1] = np.arctan2(-Rijs[:, :, 0, 2], Rijs[:, :, 1, 2]) - cls[idx + n_pairs, :, :, 1, 0] = np.arctan2( + cl_angles[idx, :, :, 1, 0] = np.arctan2( + Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] + ) + cl_angles[idx, :, :, 1, 1] = np.arctan2( + -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] + ) + cl_angles[idx + n_pairs, :, :, 1, 0] = np.arctan2( Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] ) - cls[idx + n_pairs, :, :, 1, 1] = np.arctan2( + cl_angles[idx + n_pairs, :, :, 1, 1] = np.arctan2( -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] ) @@ -339,12 +359,16 @@ def generate_relative_rotations( g2_Rj[:, [0, 2]] = -g2_Rj[:, [0, 2]] Rijs = np.transpose(g2_Rj, axes=(0, 2, 1)) @ Ri[:, None] - cls[idx, :, :, 2, 0] = np.arctan2(Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1]) - cls[idx, :, :, 2, 1] = np.arctan2(-Rijs[:, :, 0, 2], Rijs[:, :, 1, 2]) - cls[idx + n_pairs, :, :, 2, 0] = np.arctan2( + cl_angles[idx, :, :, 2, 0] = np.arctan2( + Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] + ) + cl_angles[idx, :, :, 2, 1] = np.arctan2( + -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] + ) + cl_angles[idx + n_pairs, :, :, 2, 0] = np.arctan2( Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] ) - cls[idx + n_pairs, :, :, 2, 1] = np.arctan2( + cl_angles[idx + n_pairs, :, :, 2, 1] = np.arctan2( -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] ) @@ -354,15 +378,42 @@ def generate_relative_rotations( g3_Rj[:, 0:2] = -g3_Rj[:, 0:2] Rijs = np.transpose(g3_Rj, axes=(0, 2, 1)) @ Ri[:, None] - cls[idx, :, :, 3, 0] = np.arctan2(Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1]) - cls[idx, :, :, 3, 1] = np.arctan2(-Rijs[:, :, 0, 2], Rijs[:, :, 1, 2]) - cls[idx + n_pairs, :, :, 3, 0] = np.arctan2( + cl_angles[idx, :, :, 3, 0] = np.arctan2( + Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] + ) + cl_angles[idx, :, :, 3, 1] = np.arctan2( + -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] + ) + cl_angles[idx + n_pairs, :, :, 3, 0] = np.arctan2( Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] ) - cls[idx + n_pairs, :, :, 3, 1] = np.arctan2( + cl_angles[idx + n_pairs, :, :, 3, 1] = np.arctan2( -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] ) idx += 1 - return cls + return cl_angles + + @staticmethod + def generate_commonline_indices(cl_angles): + # Make all angles non-negative and convert to degrees. + cl_angles = (cl_angles + 2 * np.pi) % (2 * np.pi) + cl_angles = cl_angles * 180 / np.pi + + # Flatten the stack + og_shape = cl_angles.shape + cl_angles = np.reshape(cl_angles, (np.prod(og_shape[:-1]), 2)) + + # Fourier ray index + cl_ind_j = np.round(cl_angles[:, 0]).astype("int") % 360 + cl_ind_i = np.round(cl_angles[:, 1]).astype("int") % 360 + + # Restrict Rj in-plane coordinates to <180 degrees. + is_geq_than_pi = cl_ind_j >= 180 + cl_ind_j[is_geq_than_pi] = cl_ind_j[is_geq_than_pi] - 180 + cl_ind_i[is_geq_than_pi] = (cl_ind_i[is_geq_than_pi] + 180) % 360 + + # Convert to linear indices in 360*180 correlation matrix + cl_ind = np.ravel_multi_index((cl_ind_i, cl_ind_j), dims=(360, 180)) + return cl_ind From ebad3af0985ac14e78eb180f42578cd723404031 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 11 Dec 2023 15:24:00 -0500 Subject: [PATCH 008/105] generate_gs. Partial self-commonline lookup. --- src/aspire/abinitio/commonline_d2.py | 98 +++++++++++++++++++--------- 1 file changed, 68 insertions(+), 30 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 5eb6fafce1..2c3597c694 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -32,7 +32,7 @@ def __init__( seed=None, ): """ - Initialize object for estimating 3D orientations for molecules with C3 and C4 symmetry. + Initialize object for estimating 3D orientations for molecules with D2 symmetry. :param src: The source object of 2D denoised or class-averaged images with metadata :param n_rad: The number of points in the radial direction @@ -60,6 +60,7 @@ def __init__( self.inplane_res = inplane_res self.eq_min_dist = eq_min_dist self.seed = seed + self._generate_gs() def estimate_rotations(self): """ @@ -67,7 +68,12 @@ def estimate_rotations(self): :return: Array of rotation matrices, size n_imgx3x3. """ - pass + self.generate_lookup_data() + self.generate_scl_lookup_data( + self.inplane_rotated_grid1, + self.eq_idx1, + self.eq_class1, + ) def generate_lookup_data(self): """ @@ -110,43 +116,62 @@ def generate_lookup_data(self): # two common lines. # Remove top views from sphere grids and update equator indices and classes. - sphere_grid1 = sphere_grid1[eq_class1 < 4] - sphere_grid2 = sphere_grid2[eq_class2 < 4] - eq_idx1 = eq_idx1[eq_class1 < 4] - eq_idx2 = eq_idx2[eq_class2 < 4] - eq_class1 = eq_class1[eq_class1 < 4] - eq_class2 = eq_class2[eq_class2 < 4] + self.sphere_grid1 = sphere_grid1[eq_class1 < 4] + self.sphere_grid2 = sphere_grid2[eq_class2 < 4] + self.eq_idx1 = eq_idx1[eq_class1 < 4] + self.eq_idx2 = eq_idx2[eq_class2 < 4] + self.eq_class1 = eq_class1[eq_class1 < 4] + self.eq_class2 = eq_class2[eq_class2 < 4] # Generate in-plane rotations for each grid point on the sphere. - inplane_rotated_grid1 = self.generate_inplane_rots( - sphere_grid1, self.inplane_res + self.inplane_rotated_grid1 = self.generate_inplane_rots( + self.sphere_grid1, self.inplane_res ) - inplane_rotated_grid2 = self.generate_inplane_rots( - sphere_grid2, self.inplane_res + self.inplane_rotated_grid2 = self.generate_inplane_rots( + self.sphere_grid2, self.inplane_res ) - # Generate commmon line angles induced by all relative rotation candidates. - cl_angles1 = self.generate_commonline_angles( - inplane_rotated_grid1, - inplane_rotated_grid1, - eq_idx1, - eq_idx1, - eq_class1, - eq_class1, + # Generate commmonline angles induced by all relative rotation candidates. + self.cl_angles1 = self.generate_commonline_angles( + self.inplane_rotated_grid1, + self.inplane_rotated_grid1, + self.eq_idx1, + self.eq_idx1, + self.eq_class1, + self.eq_class1, ) - cl_angles2 = self.generate_commonline_angles( - inplane_rotated_grid1, - inplane_rotated_grid2, - eq_idx1, - eq_idx2, - eq_class1, - eq_class2, + self.cl_angles2 = self.generate_commonline_angles( + self.inplane_rotated_grid1, + self.inplane_rotated_grid2, + self.eq_idx1, + self.eq_idx2, + self.eq_class1, + self.eq_class2, ) - cl_ind_1 = self.generate_commonline_indices(cl_angles1) - cl_ind_2 = self.generate_commonline_indices(cl_angles2) + # Generate commonline indices. + self.cl_ind_1 = self.generate_commonline_indices(self.cl_angles1) + self.cl_ind_2 = self.generate_commonline_indices(self.cl_angles2) + + def generate_scl_lookup_data(self, Ris, eq_idx, eq_class): + """ + Generate lookup data for self-commonlines. - return cl_angles1, cl_angles2 + :param Ris: Candidate rotation matrices, (n_sphere_grid, n_inplane_rots, 3, 3). + :param eq_idx: Equator index mask for Ris. + :param eq_class: Equator classification for Ris. + """ + # For each candidate rotation Ri we generate the set of 3 self-commonlines. + scl_angles = np.zeros((*Ris.shape[:2], 3, 2), dtype=Ris.dtype) + n_rots = len(Ris) + for i in range(n_rots): + Ri = Ris[i] + for j, g in enumerate(self.gs[1:]): + g_Ri = g * Ri + Riis = np.transpose(Ri, axes=(0, 2, 1)) @ g_Ri + + scl_angles[i, :, j, 0] = np.arctan2(Riis[:, 2, 0], -Riis[:, 2, 1]) + scl_angles[i, :, j, 1] = np.arctan2(-Riis[:, 0, 2], Riis[:, 1, 2]) @staticmethod def saff_kuijlaars(N): @@ -417,3 +442,16 @@ def generate_commonline_indices(cl_angles): # Convert to linear indices in 360*180 correlation matrix cl_ind = np.ravel_multi_index((cl_ind_i, cl_ind_j), dims=(360, 180)) return cl_ind + + def _generate_gs(self): + """ + Generate analogue to D2 rotation matrices, such that element-wise + multiplication, `*`, by gs is equivalent to matrix multiplication, + `@`, by a correspopnding rotation matrix. + """ + gs = np.ones((4, 3, 3), dtype=self.dtype) + gs[1, 1:3] = -gs[1, 1:3] + gs[2, [0, 2]] = -gs[2, [0, 2]] + gs[3, 0:2] = -gs[3, 0:2] + + self.gs = gs From 88b0347bed60bc7ac8fadab9834122d940eaa073 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 15 Dec 2023 09:11:29 -0500 Subject: [PATCH 009/105] More self-commonline stuff. --- src/aspire/abinitio/commonline_d2.py | 153 +++++++++++++-------------- 1 file changed, 75 insertions(+), 78 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 2c3597c694..006022ea7c 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -166,12 +166,65 @@ def generate_scl_lookup_data(self, Ris, eq_idx, eq_class): n_rots = len(Ris) for i in range(n_rots): Ri = Ris[i] - for j, g in enumerate(self.gs[1:]): + for k, g in enumerate(self.gs[1:]): g_Ri = g * Ri Riis = np.transpose(Ri, axes=(0, 2, 1)) @ g_Ri - scl_angles[i, :, j, 0] = np.arctan2(Riis[:, 2, 0], -Riis[:, 2, 1]) - scl_angles[i, :, j, 1] = np.arctan2(-Riis[:, 0, 2], Riis[:, 1, 2]) + scl_angles[i, :, k, 0] = np.arctan2(Riis[:, 2, 0], -Riis[:, 2, 1]) + scl_angles[i, :, k, 1] = np.arctan2(-Riis[:, 0, 2], Riis[:, 1, 2]) + + # Prepare self commonline coordinates. + scl_angles = scl_angles % (2 * np.pi) + + # Deal with non top view equators + # A non-TV equator has only one self common line. However, we clasify an + # equator as an image whose projection direction is at radial distance < + # eq_filter_angle from the great circle perpendicual to a symmetry axis, + # and not strcitly zero distance. Thus in most cases we get 2 common lines + # differing by a small difference in degrees. Actually the calculation above + # gives us two NEARLY antipodal lines, so we first flip one of them by + # adding 180 degrees to it. Then we aggregate all the rays within the range + # between these two resulting lines to compute the score of this self common + # line for this candidate. The scoring part is done in the ML function itself. + # Furthermore, the line perpendicular to the self common line, though not + # really a self common line, has the property that all its values are real + # and both halves of the line (rays differing by pi, emanating from the + # origin) have the same values, and so it 'beahves like' a self common + # line which we also register here and exploit in the ML function. + # We put the 'real' self common line at 2 first coordinates, the + # candidate for perpendicular line is in 3rd coordinate. + + # If this is a self common line with respect to x-equator then the actual self + # common line(s) is given by the self relative rotations given by the y and z + # rotation (by 180 degrees) group members, i.e. Ri^TgyRj and Ri^TgzRj + scl_angles[eq_class == 1] = scl_angles[eq_class == 1][:, :, [1, 2, 0]] + scl_angles[eq_class == 1, :, 0] = scl_angles[eq_class == 1][:, :, 0, [1, 0]] + + # If this is a self common line with respect to y-equator then the actual self + # common line(s) is given by the self relative rotations given by the x and z + # rotation (by 180 degrees) group members, i.e. Ri^TgxRj and Ri^TgzRj + scl_angles[eq_class == 2] = scl_angles[eq_class == 2][:, :, [0, 2, 1]] + scl_angles[eq_class == 2, :, 0] = scl_angles[eq_class == 2][:, :, 0, [1, 0]] + + # If this is a self common line with respect to z-equator then the actual self + # common line(s) is given by the self relative rotations given by the x and y + # rotation (by 180 degrees) group members, i.e. Ri^TgxRj and Ri^TgyRj + # No need to rearrange entries, the "real" common lines are already in + # indices 1 and 2, but flip one common line to antipodal. + scl_angles[eq_class == 3, :, 0] = scl_angles[eq_class == 3][:, :, 0, [1, 0]] + + # TODO: This section is silly! Clean up! + # Make sure angle range is <= 180 degrees. + p1 = scl_angles[eq_class > 0, :, 0] > scl_angles[eq_class > 0, :, 1] + p1 = p1[:, :, 0] & p1[:, :, 1] + p2 = scl_angles[eq_class > 0, :, 0] - scl_angles[eq_class > 0, :, 1] < -np.pi + p2 = p2[:, :, 0] | p2[:, :, 1] + p = p1 | p2 + + scl_angles[eq_class > 0] = ( + scl_angles[eq_class > 0][:, :, [1, 0, 2]] * p[:, :, None, None] + + scl_angles[eq_class > 0] * ~p[:, :, None, None] + ) @staticmethod def saff_kuijlaars(N): @@ -304,9 +357,8 @@ def generate_inplane_rots(sphere_grid, d_theta): return inplane_rotated_grid - @staticmethod def generate_commonline_angles( - Ris, Rjs, Ri_eq_idx, Rj_eq_idx, Ri_eq_class, Rj_eq_class + self, Ris, Rjs, Ri_eq_idx, Rj_eq_idx, Ri_eq_class, Rj_eq_class ): """ Compute commonline angles induced by the 4 sets of relative rotations @@ -341,80 +393,25 @@ def generate_commonline_angles( continue Ri = Ris[i] for j in unique_pairs_i: - # Compute relative rotations candidates Rij = Ri.T @ Rj Rj = Rjs[j, : (n_theta // 2)] - Rijs = np.transpose(Rj, axes=(0, 2, 1)) @ Ri[:, None] - - # Common line indices induced by Rijs - cl_angles[idx, :, :, 0, 0] = np.arctan2( - Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] - ) - cl_angles[idx, :, :, 0, 1] = np.arctan2( - -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] - ) - cl_angles[idx + n_pairs, :, :, 0, 0] = np.arctan2( - Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] - ) - cl_angles[idx + n_pairs, :, :, 0, 1] = np.arctan2( - -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] - ) - - # Compute relative rotations candidates Rij = Ri.T @ g1 @ Rj, - # where g1 = diag(1, -1, -1). - g1_Rj = Rj.copy() - g1_Rj[:, 1:3] = -g1_Rj[:, 1:3] - Rijs = np.transpose(g1_Rj, axes=(0, 2, 1)) @ Ri[:, None] - - cl_angles[idx, :, :, 1, 0] = np.arctan2( - Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] - ) - cl_angles[idx, :, :, 1, 1] = np.arctan2( - -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] - ) - cl_angles[idx + n_pairs, :, :, 1, 0] = np.arctan2( - Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] - ) - cl_angles[idx + n_pairs, :, :, 1, 1] = np.arctan2( - -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] - ) - - # Compute relative rotations candidates Rij = Ri.T @ g2 @ Rj, - # where g2 = diag(-1, 1, -1). - g2_Rj = Rj.copy() - g2_Rj[:, [0, 2]] = -g2_Rj[:, [0, 2]] - Rijs = np.transpose(g2_Rj, axes=(0, 2, 1)) @ Ri[:, None] - - cl_angles[idx, :, :, 2, 0] = np.arctan2( - Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] - ) - cl_angles[idx, :, :, 2, 1] = np.arctan2( - -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] - ) - cl_angles[idx + n_pairs, :, :, 2, 0] = np.arctan2( - Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] - ) - cl_angles[idx + n_pairs, :, :, 2, 1] = np.arctan2( - -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] - ) - - # Compute relative rotations candidates Rij = Ri.T @ g3 @ Rj, - # where g3 = diag(-1, -1, 1). - g3_Rj = Rj.copy() - g3_Rj[:, 0:2] = -g3_Rj[:, 0:2] - Rijs = np.transpose(g3_Rj, axes=(0, 2, 1)) @ Ri[:, None] - - cl_angles[idx, :, :, 3, 0] = np.arctan2( - Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] - ) - cl_angles[idx, :, :, 3, 1] = np.arctan2( - -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] - ) - cl_angles[idx + n_pairs, :, :, 3, 0] = np.arctan2( - Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] - ) - cl_angles[idx + n_pairs, :, :, 3, 1] = np.arctan2( - -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] - ) + for k, g in enumerate(self.gs): + # Compute relative rotations candidates Rij = Ri.T @ gs @ Rj + g_Rj = g * Rj + Rijs = np.transpose(g_Rj, axes=(0, 2, 1)) @ Ri[:, None] + + # Common line indices induced by Rijs + cl_angles[idx, :, :, k, 0] = np.arctan2( + Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] + ) + cl_angles[idx, :, :, k, 1] = np.arctan2( + -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] + ) + cl_angles[idx + n_pairs, :, :, k, 0] = np.arctan2( + Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] + ) + cl_angles[idx + n_pairs, :, :, k, 1] = np.arctan2( + -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] + ) idx += 1 From d05dd98c00d73f2f53cc708af261ceb2e36ad184 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 26 Feb 2024 15:16:29 -0500 Subject: [PATCH 010/105] circ_seq. more scl_lookup_data. --- src/aspire/abinitio/commonline_d2.py | 120 ++++++++++++++++++++++----- 1 file changed, 98 insertions(+), 22 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 006022ea7c..b728c5bd89 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -35,8 +35,8 @@ def __init__( Initialize object for estimating 3D orientations for molecules with D2 symmetry. :param src: The source object of 2D denoised or class-averaged images with metadata - :param n_rad: The number of points in the radial direction - :param n_theta: The number of points in the theta direction + :param n_rad: The number of points in the radial direction of Fourier image. + :param n_theta: The number of points in the theta direction of Fourier image. :param max_shift: Maximum range for shifts as a proportion of resolution. Default = 0.15. :param shift_step: Resolution of shift estimation in pixels. Default = 1 pixel. :param grid_res: Number of sampling points on sphere for projetion directions. @@ -132,7 +132,7 @@ def generate_lookup_data(self): ) # Generate commmonline angles induced by all relative rotation candidates. - self.cl_angles1 = self.generate_commonline_angles( + cl_angles1 = self.generate_commonline_angles( self.inplane_rotated_grid1, self.inplane_rotated_grid1, self.eq_idx1, @@ -140,7 +140,7 @@ def generate_lookup_data(self): self.eq_class1, self.eq_class1, ) - self.cl_angles2 = self.generate_commonline_angles( + cl_angles2 = self.generate_commonline_angles( self.inplane_rotated_grid1, self.inplane_rotated_grid2, self.eq_idx1, @@ -150,8 +150,12 @@ def generate_lookup_data(self): ) # Generate commonline indices. - self.cl_ind_1 = self.generate_commonline_indices(self.cl_angles1) - self.cl_ind_2 = self.generate_commonline_indices(self.cl_angles2) + self.cl_ind_1, self.cl_angles1 = self.generate_commonline_indices(cl_angles1) + self.cl_ind_2, self.cl_angles2 = self.generate_commonline_indices(cl_angles2) + + self.generate_scl_lookup_data( + self.inplane_rotated_grid1, self.eq_idx1, self.eq_class1 + ) def generate_scl_lookup_data(self, Ris, eq_idx, eq_class): """ @@ -161,6 +165,8 @@ def generate_scl_lookup_data(self, Ris, eq_idx, eq_class): :param eq_idx: Equator index mask for Ris. :param eq_class: Equator classification for Ris. """ + L = 360 # TODO: Maybe this should be self.n_theta + # For each candidate rotation Ri we generate the set of 3 self-commonlines. scl_angles = np.zeros((*Ris.shape[:2], 3, 2), dtype=Ris.dtype) n_rots = len(Ris) @@ -179,8 +185,8 @@ def generate_scl_lookup_data(self, Ris, eq_idx, eq_class): # Deal with non top view equators # A non-TV equator has only one self common line. However, we clasify an # equator as an image whose projection direction is at radial distance < - # eq_filter_angle from the great circle perpendicual to a symmetry axis, - # and not strcitly zero distance. Thus in most cases we get 2 common lines + # eq_filter_angle from the great circle perpendicural to a symmetry axis, + # and not strictly zero distance. Thus in most cases we get 2 common lines # differing by a small difference in degrees. Actually the calculation above # gives us two NEARLY antipodal lines, so we first flip one of them by # adding 180 degrees to it. Then we aggregate all the rays within the range @@ -189,7 +195,7 @@ def generate_scl_lookup_data(self, Ris, eq_idx, eq_class): # Furthermore, the line perpendicular to the self common line, though not # really a self common line, has the property that all its values are real # and both halves of the line (rays differing by pi, emanating from the - # origin) have the same values, and so it 'beahves like' a self common + # origin) have the same values, and so it 'behaves like' a self common # line which we also register here and exploit in the ML function. # We put the 'real' self common line at 2 first coordinates, the # candidate for perpendicular line is in 3rd coordinate. @@ -213,19 +219,81 @@ def generate_scl_lookup_data(self, Ris, eq_idx, eq_class): # indices 1 and 2, but flip one common line to antipodal. scl_angles[eq_class == 3, :, 0] = scl_angles[eq_class == 3][:, :, 0, [1, 0]] - # TODO: This section is silly! Clean up! + # TODO: Maybe a cleaner way to do this. # Make sure angle range is <= 180 degrees. + # p1 marks "equator" self-commonlines where both entries of the first + # scl are greater than both entries of the second scl. p1 = scl_angles[eq_class > 0, :, 0] > scl_angles[eq_class > 0, :, 1] p1 = p1[:, :, 0] & p1[:, :, 1] + # p2 marks "equator" self-commonlines where the angle range between the + # first and second sets of self-commonlines is greater than 180. p2 = scl_angles[eq_class > 0, :, 0] - scl_angles[eq_class > 0, :, 1] < -np.pi p2 = p2[:, :, 0] | p2[:, :, 1] p = p1 | p2 + # Swap entries satisfying either of the above conditions. scl_angles[eq_class > 0] = ( scl_angles[eq_class > 0][:, :, [1, 0, 2]] * p[:, :, None, None] + scl_angles[eq_class > 0] * ~p[:, :, None, None] ) + # Convert angles from radians to degrees. + scl_angles = np.round(scl_angles * 180 / np.pi) % L + import pdb + + pdb.set_trace() + # Create candidate common line linear indices lists for equators. + # As indicated above for equator candidate, for each self common line we + # don't get a single coordinate but a range of them. Here we register a + # list of coordinates for each such self common line candidate. + non_top_view_eq_idx = np.where(eq_class > 0)[0] + n_eq = len(non_top_view_eq_idx) + n_inplane_rots = Ris.shape[1] + count_eq = 0 + + # eq_lin_idx_lists[i,j,1] registers a list of linear indices of the j'th + # in-plane rotation of the range for the (only) self common line of the i'th + # candidate. eq_lin_idx_lists[i,j,2] registers the actual (integer) angle + # of the self common line in the 2D Fourier space. Note that we need only + # one number since each self common line has radial coordinates of the form + # (theta, theta+180). + eq_lin_idx_lists = [] + for i in list(non_top_view_eq_idx): + i_list = [] + for j in range(n_inplane_rots): + idx1 = self.circ_seq(scl_angles[i, j, 0, 0], scl_angles[i, j, 1, 0], L) + idx2 = self.circ_seq(scl_angles[i, j, 0, 1], scl_angles[i, j, 1, 1], L) + + # Adjust so idx2 is in [0, 180) range. + idx2[idx2 >= 180] = (idx2[idx2 >= 180] - L // 2) % (L // 2) + idx1[idx2 >= 180] = (idx1[idx2 >= 180] + L // 2) % L + print(i, j, idx1, idx2) + # register indices in list. + i_list.append(np.ravel_multi_index((idx1, idx2), (L, L // 2))) + i_list.append(idx2) + + eq_lin_idx_lists.append(i_list) + + @staticmethod + def circ_seq(n1, n2, L): + """ + Make a circular sequence of integers between n1 and n2 modulo L. + + :param n1: First integer in sequence. + :param n2: Last integer in sequence. + :param L: Modulus of values in sequence. + :return: Circular sequence modulo L. + """ + if n2 < n1: + n2 += L + if n1 == n2: + return np.array(n1).astype(int) + + seq = np.arange(n1, n2) % L + seq[abs(seq) < 1e-8] = L + + return seq.astype(int) + @staticmethod def saff_kuijlaars(N): """ @@ -333,7 +401,7 @@ def generate_inplane_rots(sphere_grid, d_theta): dtype = sphere_grid.dtype # Generate one rotation for each point on the sphere. n_rots = len(sphere_grid) - Ri2 = np.column_stack((-sphere_grid[:, 2], sphere_grid[:, 1], np.zeros(n_rots))) + Ri2 = np.column_stack((-sphere_grid[:, 1], sphere_grid[:, 0], np.zeros(n_rots))) Ri2 /= np.linalg.norm(Ri2, axis=1)[:, None] Ri1 = np.cross(Ri2, sphere_grid) Ri1 /= np.linalg.norm(Ri1, axis=1)[:, None] @@ -345,8 +413,9 @@ def generate_inplane_rots(sphere_grid, d_theta): # Generate in-plane rotations. d_theta *= np.pi / 180 + # TODO: Negative signs to match matlab. inplane_rots = Rotation.about_axis( - "z", np.arange(0, 2 * np.pi, d_theta), dtype=dtype + "z", np.arange(0, -2 * np.pi, -d_theta), dtype=dtype ).matrices n_inplane_rots = len(inplane_rots) @@ -375,13 +444,13 @@ def generate_commonline_angles( :return: Commonline angles induced by relative rotation candidates. """ n_rots_i = len(Ris) - n_theta = Ris.shape[1] # Same for Rjs + n_theta = Ris.shape[1] # Same for Rjs, TODO: Don't call this n_theta # Generate upper triangular table of indicators of all pairs which are not # equators with respect to the same symmetry axis (named unique_pairs). eq_table = np.outer(Ri_eq_idx, Rj_eq_idx) in_same_class = (Ri_eq_class[:, None] - Rj_eq_class.T[None]) == 0 - eq2eq_Rij_table = np.triu(~(eq_table * in_same_class)) + eq2eq_Rij_table = np.triu(~(eq_table * in_same_class), 1) n_pairs = np.sum(eq2eq_Rij_table) idx = 0 @@ -419,6 +488,8 @@ def generate_commonline_angles( @staticmethod def generate_commonline_indices(cl_angles): + # TODO: This is not accounting for n_theta other than 360! + # Make all angles non-negative and convert to degrees. cl_angles = (cl_angles + 2 * np.pi) % (2 * np.pi) cl_angles = cl_angles * 180 / np.pi @@ -428,17 +499,22 @@ def generate_commonline_indices(cl_angles): cl_angles = np.reshape(cl_angles, (np.prod(og_shape[:-1]), 2)) # Fourier ray index - cl_ind_j = np.round(cl_angles[:, 0]).astype("int") % 360 - cl_ind_i = np.round(cl_angles[:, 1]).astype("int") % 360 + row_sub = np.round(cl_angles[:, 0]).astype("int") % 360 + col_sub = np.round(cl_angles[:, 1]).astype("int") % 360 # Restrict Rj in-plane coordinates to <180 degrees. - is_geq_than_pi = cl_ind_j >= 180 - cl_ind_j[is_geq_than_pi] = cl_ind_j[is_geq_than_pi] - 180 - cl_ind_i[is_geq_than_pi] = (cl_ind_i[is_geq_than_pi] + 180) % 360 + is_geq_than_pi = col_sub >= 180 + col_sub[is_geq_than_pi] = col_sub[is_geq_than_pi] - 180 + row_sub[is_geq_than_pi] = (row_sub[is_geq_than_pi] + 180) % 360 + + # Convert to linear indices in 360*180 correlation matrix (same as cls_lookup in matlab) + cl_ind = np.ravel_multi_index((row_sub, col_sub), dims=(360, 180)) + + # Reshape cl_angles (to match matlab `cls`) + cl_angles = cl_angles.reshape(og_shape) - # Convert to linear indices in 360*180 correlation matrix - cl_ind = np.ravel_multi_index((cl_ind_i, cl_ind_j), dims=(360, 180)) - return cl_ind + # Return as integer indices. + return cl_ind, np.rint(cl_angles).astype(int) def _generate_gs(self): """ From ee70cfa5e6b8c783e05ef741945606705013bb64 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 29 Feb 2024 15:58:19 -0500 Subject: [PATCH 011/105] Finish scl_lookup --- src/aspire/abinitio/commonline_d2.py | 78 +++++++++++++++++----------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index b728c5bd89..a702902928 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -69,11 +69,7 @@ def estimate_rotations(self): :return: Array of rotation matrices, size n_imgx3x3. """ self.generate_lookup_data() - self.generate_scl_lookup_data( - self.inplane_rotated_grid1, - self.eq_idx1, - self.eq_class1, - ) + self.generate_scl_lookup_data() def generate_lookup_data(self): """ @@ -153,13 +149,26 @@ def generate_lookup_data(self): self.cl_ind_1, self.cl_angles1 = self.generate_commonline_indices(cl_angles1) self.cl_ind_2, self.cl_angles2 = self.generate_commonline_indices(cl_angles2) - self.generate_scl_lookup_data( - self.inplane_rotated_grid1, self.eq_idx1, self.eq_class1 + def generate_scl_lookup_data(self): + """ + Generate lookup data for self-commonlines. + + :param Ris: Candidate rotation matrices, (n_sphere_grid, n_inplane_rots, 3, 3). + :param eq_idx: Equator index mask for Ris. + :param eq_class: Equator classification for Ris. + """ + self.scl_angles1 = self.generate_scl_angles( + self.inplane_rotated_grid1, + self.eq_idx1, + self.eq_class1, + ) + self.scl_ind_1, self.scl_eq_lin_idx_lists_1 = self.generate_scl_indices( + self.scl_angles1, self.eq_class1 ) - def generate_scl_lookup_data(self, Ris, eq_idx, eq_class): + def generate_scl_angles(self, Ris, eq_idx, eq_class): """ - Generate lookup data for self-commonlines. + Generate self-commonline angles. :param Ris: Candidate rotation matrices, (n_sphere_grid, n_inplane_rots, 3, 3). :param eq_idx: Equator index mask for Ris. @@ -172,7 +181,8 @@ def generate_scl_lookup_data(self, Ris, eq_idx, eq_class): n_rots = len(Ris) for i in range(n_rots): Ri = Ris[i] - for k, g in enumerate(self.gs[1:]): + # TODO: Reversing self.gs here to match matlab. Should use as is. + for k, g in enumerate(self.gs[::-1][:3]): g_Ri = g * Ri Riis = np.transpose(Ri, axes=(0, 2, 1)) @ g_Ri @@ -237,42 +247,49 @@ def generate_scl_lookup_data(self, Ris, eq_idx, eq_class): + scl_angles[eq_class > 0] * ~p[:, :, None, None] ) - # Convert angles from radians to degrees. + # Convert angles from radians to degrees (indices). scl_angles = np.round(scl_angles * 180 / np.pi) % L - import pdb - pdb.set_trace() + return scl_angles + + def generate_scl_indices(self, scl_angles, eq_class): + L = 360 + # Create candidate common line linear indices lists for equators. # As indicated above for equator candidate, for each self common line we # don't get a single coordinate but a range of them. Here we register a # list of coordinates for each such self common line candidate. non_top_view_eq_idx = np.where(eq_class > 0)[0] n_eq = len(non_top_view_eq_idx) - n_inplane_rots = Ris.shape[1] + n_inplane_rots = scl_angles.shape[1] count_eq = 0 - # eq_lin_idx_lists[i,j,1] registers a list of linear indices of the j'th + # eq_lin_idx_lists[1,i,j] registers a list of linear indices of the j'th # in-plane rotation of the range for the (only) self common line of the i'th - # candidate. eq_lin_idx_lists[i,j,2] registers the actual (integer) angle + # candidate. eq_lin_idx_lists[2,i,j] registers the actual (integer) angle # of the self common line in the 2D Fourier space. Note that we need only # one number since each self common line has radial coordinates of the form # (theta, theta+180). - eq_lin_idx_lists = [] - for i in list(non_top_view_eq_idx): - i_list = [] + eq_lin_idx_lists = np.empty((2, n_eq, n_inplane_rots), dtype=object) + for i in non_top_view_eq_idx.tolist(): for j in range(n_inplane_rots): idx1 = self.circ_seq(scl_angles[i, j, 0, 0], scl_angles[i, j, 1, 0], L) idx2 = self.circ_seq(scl_angles[i, j, 0, 1], scl_angles[i, j, 1, 1], L) # Adjust so idx2 is in [0, 180) range. - idx2[idx2 >= 180] = (idx2[idx2 >= 180] - L // 2) % (L // 2) idx1[idx2 >= 180] = (idx1[idx2 >= 180] + L // 2) % L - print(i, j, idx1, idx2) + idx2[idx2 >= 180] = (idx2[idx2 >= 180] - L // 2) % (L // 2) + # register indices in list. - i_list.append(np.ravel_multi_index((idx1, idx2), (L, L // 2))) - i_list.append(idx2) + eq_lin_idx_lists[0, count_eq, j] = np.ravel_multi_index( + (idx1, idx2), (L, L // 2) + ) + eq_lin_idx_lists[1, count_eq, j] = idx2 + count_eq += 1 - eq_lin_idx_lists.append(i_list) + scl_indices, _ = self.generate_commonline_indices(scl_angles) + + return scl_indices, eq_lin_idx_lists @staticmethod def circ_seq(n1, n2, L): @@ -289,10 +306,9 @@ def circ_seq(n1, n2, L): if n1 == n2: return np.array(n1).astype(int) - seq = np.arange(n1, n2) % L - seq[abs(seq) < 1e-8] = L + seq = np.arange(n1, n2 + 1).astype(int) % L - return seq.astype(int) + return seq @staticmethod def saff_kuijlaars(N): @@ -484,16 +500,16 @@ def generate_commonline_angles( idx += 1 + # Make all angles non-negative and convert to degrees. + cl_angles = (cl_angles + 2 * np.pi) % (2 * np.pi) + cl_angles = cl_angles * 180 / np.pi + return cl_angles @staticmethod def generate_commonline_indices(cl_angles): # TODO: This is not accounting for n_theta other than 360! - # Make all angles non-negative and convert to degrees. - cl_angles = (cl_angles + 2 * np.pi) % (2 * np.pi) - cl_angles = cl_angles * 180 / np.pi - # Flatten the stack og_shape = cl_angles.shape cl_angles = np.reshape(cl_angles, (np.prod(og_shape[:-1]), 2)) From 80a1ef14158ab5b9c41625fb358e205b834006a6 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 8 Mar 2024 16:02:51 -0500 Subject: [PATCH 012/105] all_eq_measures and partial compute_scl_scores --- src/aspire/abinitio/commonline_d2.py | 121 ++++++++++++++++++++++++++- 1 file changed, 120 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index a702902928..dff513e923 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1,9 +1,11 @@ import logging import numpy as np +from numpy.linalg import norm from aspire.abinitio import CLOrient3D -from aspire.utils import Rotation +from aspire.operators import PolarFT +from aspire.utils import Rotation, trange logger = logging.getLogger(__name__) @@ -70,6 +72,114 @@ def estimate_rotations(self): """ self.generate_lookup_data() self.generate_scl_lookup_data() + self.compute_scl_scores() + + def compute_scl_scores(self): + pf = self.pf + n_img = self.n_img + L = self.src.L + n_theta = self.n_theta + max_shift_1d = self.max_shift + shift_step = self.shift_step + + # Compute the correlation over all shifts. + # Generate Shifts. + r_max = pf.shape[-1] + shifts, shift_phases, _ = self._generate_shift_phase_and_filter( + r_max, max_shift_1d, shift_step + ) + n_shifts = len(shifts) + + # Reconstruct the full polar Fourier for use in correlation. self.pf only consists of + # rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1). + pf_full = PolarFT.half_to_full(pf) + + for i in trange(n_img): + pf_i = pf[i] + pf_full_i = pf_full[i] + + # Generate shifted versions of images. + pf_i_shifted = np.array( + [pf_i * shift_phase for shift_phase in shift_phases] + ) + pf_i_shifted = np.reshape(pf_i_shifted, (n_shifts * n_theta // 2, r_max)) + + # # Normalize each ray. + pf_full_i /= norm(pf_full_i, axis=1)[..., np.newaxis] + pf_i_shifted /= norm(pf_i_shifted, axis=1)[..., np.newaxis] + + # Compute max correlation over all shifts. + corrs = np.real(pf_i_shifted @ pf_full_i.T) + corrs = np.reshape(corrs, (n_shifts, n_theta // 2, n_theta)) + corrs = np.max(corrs, axis=0) + + # Map correlations to probabilities (in the spirit of Maximum Likelihood). + corrs = 0.5 * (corrs + 1) + + # Compute equator measures. + eq_measures = self.all_eq_measures(corrs) + + def all_eq_measures(self, corrs): + """ + Compute a measure of how much an image from data is close to be an equator. + """ + # First compute the eq measure (corrs(scl-k,scl+k) for k=1:90) + # An eqautor image of a D2 molecule has the following property: If t_i is + # the angle of one of the rays of the self common line then all the pairs of + # rays of the form (t_i-k,t_i+k) for k=1:90 are identical. For each t_i we + # average over correlations between the lines (t_i-k,t_i+k) for k=1:90 + # to measure the likelihood that the image is an equator and the ray (line) + # with angle t_i is a self common line. + # (This first loop can be done once outside this function and then pass + # idx as an argument). + idx = np.zeros((180, 90, 2)) + idx_1 = np.mod(np.vstack((-np.arange(1, 91), np.arange(1, 91))), 360) + idx[0, :, :] = idx_1.T + for k in range(1, 180): + idx[k, :, :] = np.mod(idx_1.T + k, 360) + idx = np.mod(idx, 360) + + idx_1 = idx[:, :, 0].flatten() + idx_2 = idx[:, :, 1].flatten() + + # Make all Ri coordinates < 180 and compute linear indices for corrrelations + bigger_than_180 = idx_1 >= 180 + idx_1[bigger_than_180] = idx_1[bigger_than_180] - 180 + idx_2[bigger_than_180] = (idx_2[bigger_than_180] + 180) % 360 + + # Compute correlations. + eq_corrs = corrs[idx_1.astype(int), idx_2.astype(int)] + eq_corrs = eq_corrs.reshape(180, 90) + corrs_mean = np.mean(eq_corrs, axis=1) + + # Now compute correlations for normals to scls. + # An eqautor image of a D2 molecule has the additional following property: + # The normal line to a self common line in 2D Fourier plane is real valued + # and both of its rays have identical values. We use the correlation + # between one Fourier ray of the normal to a self common line candidate t_i + # with its anti-podal as an additional way to measure if the image is an + # equator and t_i+0.5*pi is the normal to its self common line. + r = 2 + + normal_2_scl_idx = np.zeros((180, 2 * r + 1)) + normal_2_scl_idx_1 = np.mod(180 - np.arange(90 - r, 90 + r + 1), 360) + normal_2_scl_idx[0, :] = normal_2_scl_idx_1 + for k in range(1, 180): + normal_2_scl_idx[k, :] = np.mod(normal_2_scl_idx_1 + k, 360) + + # Make all Ri coordinates <=180 and compute linear indices for corrrelations + bigger_than_180 = normal_2_scl_idx >= 180 + normal_2_scl_idx[bigger_than_180] = normal_2_scl_idx[bigger_than_180] - 180 + + # Compute correlations for normals. + normal_2_scl_idx = normal_2_scl_idx.flatten() + normal_corrs = corrs[ + normal_2_scl_idx.astype(int), normal_2_scl_idx.astype(int) + 180 + ] + normal_corrs = normal_corrs.reshape(180, 2 * r + 1) + normal_corrs_max = np.max(normal_corrs, axis=1) + + return corrs_mean * normal_corrs_max def generate_lookup_data(self): """ @@ -162,9 +272,18 @@ def generate_scl_lookup_data(self): self.eq_idx1, self.eq_class1, ) + self.scl_angles2 = self.generate_scl_angles( + self.inplane_rotated_grid2, + self.eq_idx2, + self.eq_class2, + ) + self.scl_ind_1, self.scl_eq_lin_idx_lists_1 = self.generate_scl_indices( self.scl_angles1, self.eq_class1 ) + self.scl_ind_2, self.scl_eq_lin_idx_lists_2 = self.generate_scl_indices( + self.scl_angles2, self.eq_class2 + ) def generate_scl_angles(self, Ris, eq_idx, eq_class): """ From 52d8974cf3d2bf01fff960a8774da0b64ffb7402 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 14 Mar 2024 09:19:57 -0400 Subject: [PATCH 013/105] compute_scl_scores. Enforce aspire-python cl correlation convention, ie. (180, 360) shape. --- src/aspire/abinitio/commonline_d2.py | 125 ++++++++++++++++++++++----- 1 file changed, 102 insertions(+), 23 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index dff513e923..2dfcfe0918 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -60,6 +60,7 @@ def __init__( self.grid_res = grid_res self.inplane_res = inplane_res + self.n_inplane_rots = int(360 / self.inplane_res) self.eq_min_dist = eq_min_dist self.seed = seed self._generate_gs() @@ -75,12 +76,16 @@ def estimate_rotations(self): self.compute_scl_scores() def compute_scl_scores(self): + """ + Compute correlations for self-commonline candidates. + """ pf = self.pf n_img = self.n_img - L = self.src.L n_theta = self.n_theta max_shift_1d = self.max_shift shift_step = self.shift_step + n_eq = len(self.non_tv_eq_idx) + n_inplane = self.n_inplane_rots # Compute the correlation over all shifts. # Generate Shifts. @@ -94,6 +99,19 @@ def compute_scl_scores(self): # rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1). pf_full = PolarFT.half_to_full(pf) + # Run ML in parallel + scl_matrix = np.concatenate((self.scl_idx_1, self.scl_idx_2)) + M = len(scl_matrix) // 3 + corrs_out = np.zeros((n_img, M), dtype=self.dtype) + scl_idx = scl_matrix.reshape(M, 3) + + # Get non-equator indices to use with corrs matrix. + non_eq_lin_idx = self.non_eq_idx.flatten() + n_non_eq = len(non_eq_lin_idx) + non_eq_idx = np.unravel_index( + scl_idx[non_eq_lin_idx].flatten(), (n_theta // 2, n_theta) + ) + for i in trange(n_img): pf_i = pf[i] pf_full_i = pf_full[i] @@ -109,9 +127,9 @@ def compute_scl_scores(self): pf_i_shifted /= norm(pf_i_shifted, axis=1)[..., np.newaxis] # Compute max correlation over all shifts. - corrs = np.real(pf_i_shifted @ pf_full_i.T) - corrs = np.reshape(corrs, (n_shifts, n_theta // 2, n_theta)) - corrs = np.max(corrs, axis=0) + corrs = np.real(pf_i_shifted @ np.conj(pf_full_i).T) + corrs = np.reshape(corrs, (n_theta // 2, n_shifts, n_theta)) + corrs = np.max(corrs, axis=1) # Map correlations to probabilities (in the spirit of Maximum Likelihood). corrs = 0.5 * (corrs + 1) @@ -119,6 +137,31 @@ def compute_scl_scores(self): # Compute equator measures. eq_measures = self.all_eq_measures(corrs) + # Handle the cases: Non-equator, Non-top-view equator, and Top view images. + # 1. Non-equators: just take product of probabilities. + prod_corrs = np.prod(corrs[non_eq_idx].reshape(n_non_eq, 3), axis=1) + corrs_out[i, non_eq_lin_idx] = prod_corrs + + # 2. Non-topview equators: adjust scores by eq_measures + for eq_idx in range(n_eq): + for j in range(n_inplane): + # Take the correlations for the self common line candidate of the + # "equator rotation" `eq_idx` with respect to image i, and + # multiply by all scores from the function eq_measures (see + # documentation inside the function ). Then take maximum over + # all the scores. + scl_idx_list = np.unravel_index( + self.scl_idx_lists[0, eq_idx, j], (n_theta // 2, n_theta) + ) + true_scls_corrs = corrs[scl_idx_list] + scls_cand_idx = self.scl_idx_lists[1, eq_idx, j] + eq_measures_j = eq_measures[scls_cand_idx] + measures_agg = true_scls_corrs * eq_measures_j + k = self.non_tv_eq_idx[eq_idx] + corrs_out[i, k * n_inplane + j] = np.max(measures_agg) + + self.scls_scores = corrs_out + def all_eq_measures(self, corrs): """ Compute a measure of how much an image from data is close to be an equator. @@ -226,6 +269,7 @@ def generate_lookup_data(self): self.sphere_grid2 = sphere_grid2[eq_class2 < 4] self.eq_idx1 = eq_idx1[eq_class1 < 4] self.eq_idx2 = eq_idx2[eq_class2 < 4] + self.eq_idx = np.concatenate((self.eq_idx1, self.eq_idx2)) self.eq_class1 = eq_class1[eq_class1 < 4] self.eq_class2 = eq_class2[eq_class2 < 4] @@ -256,17 +300,14 @@ def generate_lookup_data(self): ) # Generate commonline indices. - self.cl_ind_1, self.cl_angles1 = self.generate_commonline_indices(cl_angles1) - self.cl_ind_2, self.cl_angles2 = self.generate_commonline_indices(cl_angles2) + self.cl_idx_1, self.cl_angles1 = self.generate_commonline_indices(cl_angles1) + self.cl_idx_2, self.cl_angles2 = self.generate_commonline_indices(cl_angles2) def generate_scl_lookup_data(self): """ Generate lookup data for self-commonlines. - - :param Ris: Candidate rotation matrices, (n_sphere_grid, n_inplane_rots, 3, 3). - :param eq_idx: Equator index mask for Ris. - :param eq_class: Equator classification for Ris. """ + # Get self-commonline angles. self.scl_angles1 = self.generate_scl_angles( self.inplane_rotated_grid1, self.eq_idx1, @@ -278,12 +319,50 @@ def generate_scl_lookup_data(self): self.eq_class2, ) - self.scl_ind_1, self.scl_eq_lin_idx_lists_1 = self.generate_scl_indices( + # Get self-commonline indices. + self.scl_idx_1, self.scl_eq_lin_idx_lists_1 = self.generate_scl_indices( self.scl_angles1, self.eq_class1 ) - self.scl_ind_2, self.scl_eq_lin_idx_lists_2 = self.generate_scl_indices( + self.scl_idx_2, self.scl_eq_lin_idx_lists_2 = self.generate_scl_indices( self.scl_angles2, self.eq_class2 ) + self.scl_idx_lists = np.concatenate( + (self.scl_eq_lin_idx_lists_1, self.scl_eq_lin_idx_lists_2), axis=1 + ) + + # Compute non-equator indices. + # Register non equator indices. Denote by C_ij the j'th in-plane rotation of + # the i'th ML candidate, and arrange all candidates in a list with their in-plane + # rotations in the order: C_11,...,C_1r,...,C_m1,...,C_mr where m is the + # number of candidates and r is the number of in plane rotations. Here we + # create a sub-list of only non equator candidates, i.e., if i_1,...,i_p are + # non equators then we have the sub list is + # C_(i_1)1,...,C(i_1)r,...C_(i_p)1,...,C_(i_p)r. + n_non_eq = np.sum(self.eq_class1 == 0) + np.sum(self.eq_class2 == 0) + non_eq_idx = np.zeros((n_non_eq, int(self.n_inplane_rots))) + non_eq_idx[:, 0] = ( + np.hstack( + ( + np.where(self.eq_class1 == 0)[0], + len(self.eq_class1) + np.where(self.eq_class2 == 0)[0], + ) + ) + * self.n_inplane_rots + ) + for i in range(1, self.n_inplane_rots): + non_eq_idx[:, i] = non_eq_idx[:, 0] + i + + self.non_eq_idx = non_eq_idx.astype(int) + + # Non-topview equator indices. + non_tv_eq_idx = np.concatenate( + ( + np.where(self.eq_class1 > 0)[0], + len(self.eq_class1) + np.where(self.eq_class2 > 0)[0], + ) + ) + + self.non_tv_eq_idx = non_tv_eq_idx.astype(int) def generate_scl_angles(self, Ris, eq_idx, eq_class): """ @@ -395,15 +474,15 @@ def generate_scl_indices(self, scl_angles, eq_class): idx1 = self.circ_seq(scl_angles[i, j, 0, 0], scl_angles[i, j, 1, 0], L) idx2 = self.circ_seq(scl_angles[i, j, 0, 1], scl_angles[i, j, 1, 1], L) - # Adjust so idx2 is in [0, 180) range. - idx1[idx2 >= 180] = (idx1[idx2 >= 180] + L // 2) % L - idx2[idx2 >= 180] = (idx2[idx2 >= 180] - L // 2) % (L // 2) + # Adjust so idx1 is in [0, 180) range. + idx1[idx1 >= 180] = (idx1[idx1 >= 180] - L // 2) % (L // 2) + idx2[idx1 >= 180] = (idx2[idx1 >= 180] + L // 2) % L # register indices in list. eq_lin_idx_lists[0, count_eq, j] = np.ravel_multi_index( - (idx1, idx2), (L, L // 2) + (idx1, idx2), (L // 2, L) ) - eq_lin_idx_lists[1, count_eq, j] = idx2 + eq_lin_idx_lists[1, count_eq, j] = idx1 count_eq += 1 scl_indices, _ = self.generate_commonline_indices(scl_angles) @@ -637,19 +716,19 @@ def generate_commonline_indices(cl_angles): row_sub = np.round(cl_angles[:, 0]).astype("int") % 360 col_sub = np.round(cl_angles[:, 1]).astype("int") % 360 - # Restrict Rj in-plane coordinates to <180 degrees. - is_geq_than_pi = col_sub >= 180 - col_sub[is_geq_than_pi] = col_sub[is_geq_than_pi] - 180 - row_sub[is_geq_than_pi] = (row_sub[is_geq_than_pi] + 180) % 360 + # Restrict Ri in-plane coordinates to <180 degrees. + is_geq_than_pi = row_sub >= 180 + row_sub[is_geq_than_pi] = row_sub[is_geq_than_pi] - 180 + col_sub[is_geq_than_pi] = (col_sub[is_geq_than_pi] + 180) % 360 # Convert to linear indices in 360*180 correlation matrix (same as cls_lookup in matlab) - cl_ind = np.ravel_multi_index((row_sub, col_sub), dims=(360, 180)) + cl_idx = np.ravel_multi_index((row_sub, col_sub), dims=(180, 360)) # Reshape cl_angles (to match matlab `cls`) cl_angles = cl_angles.reshape(og_shape) # Return as integer indices. - return cl_ind, np.rint(cl_angles).astype(int) + return cl_idx, np.rint(cl_angles).astype(int) def _generate_gs(self): """ From bd053852a6149ab97b2bf100441a06481009f5a1 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 14 Mar 2024 14:23:26 -0400 Subject: [PATCH 014/105] precompute shifted polar fourier. Confirm corrs reshape. --- src/aspire/abinitio/commonline_d2.py | 58 +++++++++++++++------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 2dfcfe0918..117e43b7ee 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -71,15 +71,41 @@ def estimate_rotations(self): :return: Array of rotation matrices, size n_imgx3x3. """ + self.compute_shifted_pf() self.generate_lookup_data() self.generate_scl_lookup_data() self.compute_scl_scores() + def compute_shifted_pf(self): + pf = self.pf + + # Generate shift phases. + r_max = pf.shape[-1] + shifts, shift_phases, _ = self._generate_shift_phase_and_filter( + r_max, self.max_shift, self.shift_step + ) + self.n_shifts = len(shifts) + + # Reconstruct full polar Fourier for use in correlation. + pf /= norm(pf, axis=2)[..., np.newaxis] # Normalize each ray. + self.pf_full = PolarFT.half_to_full(pf) + + # Pre-compute shifted pf's. + pf_shifted = (pf * shift_phases[:, None, None]).swapaxes(0, 1) + self.pf_shifted = pf_shifted.reshape( + (self.n_img, self.n_shifts * (self.n_theta // 2), r_max) + ) + + def compute_cl_scores(self): + """ + Run common lines Maximum likelihood procedure for a D2 molecule, to find + the set of rotations Ri^TgkRj, k=1,2,3,4 for each pair of images i and j. + """ + def compute_scl_scores(self): """ Compute correlations for self-commonline candidates. """ - pf = self.pf n_img = self.n_img n_theta = self.n_theta max_shift_1d = self.max_shift @@ -87,18 +113,6 @@ def compute_scl_scores(self): n_eq = len(self.non_tv_eq_idx) n_inplane = self.n_inplane_rots - # Compute the correlation over all shifts. - # Generate Shifts. - r_max = pf.shape[-1] - shifts, shift_phases, _ = self._generate_shift_phase_and_filter( - r_max, max_shift_1d, shift_step - ) - n_shifts = len(shifts) - - # Reconstruct the full polar Fourier for use in correlation. self.pf only consists of - # rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1). - pf_full = PolarFT.half_to_full(pf) - # Run ML in parallel scl_matrix = np.concatenate((self.scl_idx_1, self.scl_idx_2)) M = len(scl_matrix) // 3 @@ -113,23 +127,13 @@ def compute_scl_scores(self): ) for i in trange(n_img): - pf_i = pf[i] - pf_full_i = pf_full[i] - - # Generate shifted versions of images. - pf_i_shifted = np.array( - [pf_i * shift_phase for shift_phase in shift_phases] - ) - pf_i_shifted = np.reshape(pf_i_shifted, (n_shifts * n_theta // 2, r_max)) - - # # Normalize each ray. - pf_full_i /= norm(pf_full_i, axis=1)[..., np.newaxis] - pf_i_shifted /= norm(pf_i_shifted, axis=1)[..., np.newaxis] + pf_full_i = self.pf_full[i] + pf_i_shifted = self.pf_shifted[i] # Compute max correlation over all shifts. corrs = np.real(pf_i_shifted @ np.conj(pf_full_i).T) - corrs = np.reshape(corrs, (n_theta // 2, n_shifts, n_theta)) - corrs = np.max(corrs, axis=1) + corrs = np.reshape(corrs, (self.n_shifts, n_theta // 2, n_theta)) + corrs = np.max(corrs, axis=0) # Map correlations to probabilities (in the spirit of Maximum Likelihood). corrs = 0.5 * (corrs + 1) From a336b1bac72da92b8d081490b2eef1e5f4e97cfd Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 15 Mar 2024 13:44:17 -0400 Subject: [PATCH 015/105] add generate_scl_scores_idx_map method. --- src/aspire/abinitio/commonline_d2.py | 66 ++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 3 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 117e43b7ee..5337a5fa8e 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -286,7 +286,7 @@ def generate_lookup_data(self): ) # Generate commmonline angles induced by all relative rotation candidates. - cl_angles1 = self.generate_commonline_angles( + cl_angles1, self.eq2eq_Rij_table_11 = self.generate_commonline_angles( self.inplane_rotated_grid1, self.inplane_rotated_grid1, self.eq_idx1, @@ -294,7 +294,7 @@ def generate_lookup_data(self): self.eq_class1, self.eq_class1, ) - cl_angles2 = self.generate_commonline_angles( + cl_angles2, self.eq2eq_Rij_table_12 = self.generate_commonline_angles( self.inplane_rotated_grid1, self.inplane_rotated_grid2, self.eq_idx1, @@ -368,6 +368,8 @@ def generate_scl_lookup_data(self): self.non_tv_eq_idx = non_tv_eq_idx.astype(int) + self.generate_scl_scores_idx_map() + def generate_scl_angles(self, Ris, eq_idx, eq_class): """ Generate self-commonline angles. @@ -493,6 +495,64 @@ def generate_scl_indices(self, scl_angles, eq_class): return scl_indices, eq_lin_idx_lists + def generate_scl_scores_idx_map(self): + n_rot_1 = len(self.scl_idx_1) // (3 * self.n_inplane_rots) + n_rot_2 = len(self.scl_idx_2) // (3 * self.n_inplane_rots) + + # First the map for i 0] + if len(unique_pairs_i) == 0: + continue + i_idx_plus_offset = i_idx + (i * self.n_inplane_rots) + + for j in unique_pairs_i: + j_idx_plus_offset = j_idx + (j * self.n_inplane_rots) + oct1_ij_map[idx] = np.vstack((i_idx_plus_offset, j_idx_plus_offset)) + idx += 1 + + # First the map for i 0] + if len(unique_pairs_i) == 0: + continue + i_idx_plus_offset = i_idx + (i * self.n_inplane_rots) + + for j in unique_pairs_i: + j_idx_plus_offset = j_idx + (j * self.n_inplane_rots) + oct2_ij_map[idx] = np.vstack((i_idx_plus_offset, j_idx_plus_offset)) + idx += 1 + + tmp1 = oct1_ij_map[:, 0, :] + tmp2 = oct1_ij_map[:, 1, :] + self.oct1_ij_map = np.column_stack((tmp1.flatten(), tmp2.flatten())) + + tmp1 = oct2_ij_map[:, 0, :] + tmp2 = oct2_ij_map[:, 1, :] + self.oct2_ij_map = np.column_stack((tmp1.flatten(), tmp2.flatten())) + @staticmethod def circ_seq(n1, n2, L): """ @@ -706,7 +766,7 @@ def generate_commonline_angles( cl_angles = (cl_angles + 2 * np.pi) % (2 * np.pi) cl_angles = cl_angles * 180 / np.pi - return cl_angles + return cl_angles, eq2eq_Rij_table @staticmethod def generate_commonline_indices(cl_angles): From 9e089c6fbff673aa60fb82a8096a25a70ce2a0c2 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 15 Mar 2024 13:47:06 -0400 Subject: [PATCH 016/105] remove unused variables. --- src/aspire/abinitio/commonline_d2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 5337a5fa8e..08a984768a 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -108,8 +108,6 @@ def compute_scl_scores(self): """ n_img = self.n_img n_theta = self.n_theta - max_shift_1d = self.max_shift - shift_step = self.shift_step n_eq = len(self.non_tv_eq_idx) n_inplane = self.n_inplane_rots From a8e618ca8d03e321ad5104fd1a4fb86059f5b24b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 29 Mar 2024 16:07:28 -0400 Subject: [PATCH 017/105] completed compute_cl_scores --- src/aspire/abinitio/commonline_d2.py | 196 ++++++++++++++++++++++++--- 1 file changed, 174 insertions(+), 22 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 08a984768a..223824187f 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -5,7 +5,7 @@ from aspire.abinitio import CLOrient3D from aspire.operators import PolarFT -from aspire.utils import Rotation, trange +from aspire.utils import Rotation, tqdm, trange logger = logging.getLogger(__name__) @@ -75,8 +75,12 @@ def estimate_rotations(self): self.generate_lookup_data() self.generate_scl_lookup_data() self.compute_scl_scores() + self.compute_cl_scores() def compute_shifted_pf(self): + """ + Pre-compute shifted and full polar Fourier transforms. + """ pf = self.pf # Generate shift phases. @@ -101,6 +105,137 @@ def compute_cl_scores(self): Run common lines Maximum likelihood procedure for a D2 molecule, to find the set of rotations Ri^TgkRj, k=1,2,3,4 for each pair of images i and j. """ + # Map the self common line scores of each 2 candidate rotations R_i,R_j to + # the respective relative rotation candidate R_i^TR_j. + n_lookup_1 = len(self.scl_idx_1) // 3 + oct1_ij_map = np.vstack((self.oct1_ij_map, self.oct1_ij_map[:, [1, 0]])) + oct2_ij_map = self.oct2_ij_map + oct2_ij_map[:, 1] += n_lookup_1 + oct2_ij_map = np.vstack((oct2_ij_map, oct2_ij_map[:, [1, 0]])) + ij_map = np.vstack((oct1_ij_map, oct2_ij_map)) + + # Allocate output variables. + n_pairs = self.n_img * (self.n_img - 1) // 2 + corrs_idx = np.zeros(n_pairs, dtype=np.int64) + corrs_out = np.zeros(n_pairs, dtype=self.dtype) + ij_idx = 0 + + # Search for common lines between pairs of projections. + pbar = tqdm( + desc="Searching for commonlines between pairs of images", total=n_pairs + ) + for i in range(self.n_img): + pf_i = self.pf_shifted[i] + scores_i = self.scls_scores[i] + + for j in range(i + 1, self.n_img): + pf_j = self.pf_full[j] + + # Compute maximum correlation over all shifts. + corrs = np.real(pf_i @ np.conj(pf_j).T) + corrs = np.reshape( + corrs, (self.n_shifts, self.n_theta // 2, self.n_theta) + ) + corrs = np.max(corrs, axis=0) + + # Take the product over symmetrically induced candidates. Eq. 4.5 in paper. + cl_idx = np.unravel_index( + self.cl_idx, (self.n_theta // 2, self.n_theta) + ) + prod_corrs = corrs[cl_idx] + prod_corrs = prod_corrs.reshape(len(prod_corrs) // 4, 4) + prod_corrs = np.prod(prod_corrs, axis=1) + + # Incorporate scores of individual rotations from self-commonlines. + scores_j = self.scls_scores[j] + scores_ij = scores_i[ij_map[:, 0]] * scores_j[ij_map[:, 1]] + + # Find maximum correlations. + prod_corrs = prod_corrs * scores_ij + max_idx = np.argmax(prod_corrs) + corrs_idx[ij_idx] = max_idx + corrs_out[ij_idx] = prod_corrs[max_idx] + ij_idx += 1 + + pbar.update() + pbar.close() + + # Get estimated relative viewing directions. + self.Rijs_est = self.get_Rijs_from_lin_idx(corrs_idx) + + def get_Rijs_from_lin_idx(self, lin_idx): + """ + Restore map results from maximum-likelihood over commonlines to corresponding + relative rotations. + """ + Rijs_est = np.zeros((len(lin_idx), 4, 3, 3), dtype=self.dtype) + n_cand_per_oct = len(self.cl_idx_1) // 4 + oct1_idx = lin_idx < n_cand_per_oct + n_est_in_oct1 = np.sum(oct1_idx, dtype=int) + if n_est_in_oct1 > 0: + Rijs_est[oct1_idx] = self.get_Rijs_from_oct(lin_idx[oct1_idx], octant=1) + if n_est_in_oct1 <= len(lin_idx): + Rijs_est[~oct1_idx] = self.get_Rijs_from_oct( + lin_idx[~oct1_idx] - n_cand_per_oct, octant=2 + ) + + def get_Rijs_from_oct(self, lin_idx, octant=1): + if octant not in [1, 2]: + raise ValueError("`octant` must be 1 or 2.") + + # Get pairs lookup table. + if octant == 1: + unique_pairs = self.eq2eq_Rij_table_11 + else: + unique_pairs = self.eq2eq_Rij_table_12 + n_theta = self.n_inplane_rots + n_lookup_pairs = np.sum(unique_pairs, dtype=np.int64) + n_rots = len(self.sphere_grid1) + if octant == 1: + n_rots2 = n_rots + else: + n_rots2 = len(self.sphere_grid2) + n_pairs = len(lin_idx) + + # Map linear indices of chosen pairs of rotation candidates from ML to regular indices. + p_idx, inplane_i, inplane_j = np.unravel_index( + lin_idx, (2 * n_lookup_pairs, n_theta, n_theta // 2) + ) + transpose_idx = p_idx >= n_lookup_pairs + p_idx[transpose_idx] -= n_lookup_pairs + s = self.inplane_rotated_grid1.shape + inplane_rotated_grid = np.reshape( + self.inplane_rotated_grid1, (np.prod(s[0:2]), 3, 3) + ) + if octant == 1: + s2 = s + inplane_rotated_grid2 = inplane_rotated_grid + else: + s2 = self.inplane_rotated_grid2.shape + inplane_rotated_grid2 = np.reshape( + self.inplane_rotated_grid2, (np.prod(s2[0:2]), 3, 3) + ) + + Rijs_est = np.zeros((n_pairs, 4, 3, 3), dtype=self.dtype) + + # Convert linear indices of unique table to linear indices of index pairs table. + idx_vec = np.arange(np.prod(unique_pairs.shape)) + unique_lin_idx = idx_vec[unique_pairs.flatten()] + I, J = np.unravel_index(unique_lin_idx, (n_rots, n_rots2)) + est_idx = np.vstack((I[p_idx], J[p_idx])) + + # Assemble relative rotations Ri^TgRj using linear indices, where g is a group member of D2. + Ris_lin_idx = np.ravel_multi_index((est_idx[0], inplane_i), s[:2]) + Rjs_lin_idx = np.ravel_multi_index((est_idx[1], inplane_j), s2[:2]) + Ris = np.transpose(inplane_rotated_grid[Ris_lin_idx], (0, 2, 1)) + Rjs = np.transpose(inplane_rotated_grid2[Rjs_lin_idx], (0, 2, 1)) + + for k, g in enumerate(self.gs): + Rijs_est[:, k] = np.transpose(Ris, (0, 2, 1)) @ (g * Rjs) + + Rijs_est[transpose_idx] = np.transpose(Rijs_est[transpose_idx], (0, 1, 3, 2)) + + return Rijs_est def compute_scl_scores(self): """ @@ -299,11 +434,13 @@ def generate_lookup_data(self): self.eq_idx2, self.eq_class1, self.eq_class2, + triu=False, ) # Generate commonline indices. self.cl_idx_1, self.cl_angles1 = self.generate_commonline_indices(cl_angles1) self.cl_idx_2, self.cl_angles2 = self.generate_commonline_indices(cl_angles2) + self.cl_idx = np.hstack((self.cl_idx_1, self.cl_idx_2)) def generate_scl_lookup_data(self): """ @@ -366,6 +503,7 @@ def generate_scl_lookup_data(self): self.non_tv_eq_idx = non_tv_eq_idx.astype(int) + # Generate maps from scl indices to relative rotations. self.generate_scl_scores_idx_map() def generate_scl_angles(self, Ris, eq_idx, eq_class): @@ -499,36 +637,32 @@ def generate_scl_scores_idx_map(self): # First the map for i 0] + unique_pairs_i = idx_vec[self.eq2eq_Rij_table_11[i]] if len(unique_pairs_i) == 0: continue i_idx_plus_offset = i_idx + (i * self.n_inplane_rots) for j in unique_pairs_i: j_idx_plus_offset = j_idx + (j * self.n_inplane_rots) - oct1_ij_map[idx] = np.vstack((i_idx_plus_offset, j_idx_plus_offset)) + oct1_ij_map[:, :, idx] = np.column_stack( + (i_idx_plus_offset, j_idx_plus_offset) + ) idx += 1 # First the map for i Date: Thu, 11 Apr 2024 15:15:01 -0400 Subject: [PATCH 018/105] global J sync complete --- src/aspire/abinitio/commonline_d2.py | 209 ++++++++++++++++++++++++++- 1 file changed, 208 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 223824187f..0c8a402889 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -5,7 +5,8 @@ from aspire.abinitio import CLOrient3D from aspire.operators import PolarFT -from aspire.utils import Rotation, tqdm, trange +from aspire.utils import J_conjugate, Rotation, all_pairs, all_triplets, tqdm, trange +from aspire.utils.random import randn logger = logging.getLogger(__name__) @@ -31,6 +32,7 @@ def __init__( grid_res=1200, inplane_res=5, eq_min_dist=7, + epsilon=0.01, seed=None, ): """ @@ -63,6 +65,7 @@ def __init__( self.n_inplane_rots = int(360 / self.inplane_res) self.eq_min_dist = eq_min_dist self.seed = seed + self.epsilon = epsilon self._generate_gs() def estimate_rotations(self): @@ -77,6 +80,14 @@ def estimate_rotations(self): self.compute_scl_scores() self.compute_cl_scores() + # Handedness Synchronization + self.Rijs_sync = self._global_J_sync(self.Rijs_est) + np.save("Rijs_sync.npy", self.Rijs_sync) + np.save("Rijs_est.npy", self.Rijs_est) + import pdb + + pdb.set_trace() + def compute_shifted_pf(self): """ Pre-compute shifted and full polar Fourier transforms. @@ -179,6 +190,8 @@ def get_Rijs_from_lin_idx(self, lin_idx): lin_idx[~oct1_idx] - n_cand_per_oct, octant=2 ) + return Rijs_est + def get_Rijs_from_oct(self, lin_idx, octant=1): if octant not in [1, 2]: raise ValueError("`octant` must be 1 or 2.") @@ -691,6 +704,200 @@ def generate_scl_scores_idx_map(self): (tmp1.flatten(order="F"), tmp2.flatten(order="F")) ) + ############################# + # Methods for Global J Sync # + ############################# + + def _global_J_sync(self, Rijs): + """ + Global J-synchronization of all third row outer products. Given 3x3 matrices Rijs and viis, each + of which might contain a spurious J (ie. vij = J*vi*vj^T*J instead of vij = vi*vj^T), + we return Rijs and viis that all have either a spurious J or not. + + :param Rijs: An (n-choose-2)x3x3 array where each 3x3 slice holds an estimate for the corresponding + outer-product vi*vj^T between the third rows of the rotation matrices Ri and Rj. Each estimate + might have a spurious J independently of other estimates. + + :return: Rijs, all of which have a spurious J or not. + """ + n_img = self.n_img + + # Find best J_configuration. + J_list = self._J_configuration(Rijs) + + # Determine relative handedness of Rijs. + sign_ij_J = self._J_sync_power_method(J_list) + + # Synchronize Rijs + logger.info("Applying global handedness synchronization.") + mask = sign_ij_J == 1 + Rijs[mask] = J_conjugate(Rijs[mask]) + + return Rijs + + def _J_configuration(self, Rijs): + """ + For each triplet of indices (i, j, k), consider the relative rotations + tuples {Ri^TgmRj}, {Ri^TglRk} and {Rj^TgrRk}. Compute norms of the form + ||Ri^TgmRj*Rj^TglRk-Ri^TglRk||, ||J*Ri^TgmRj*J*Rj^TglRk-Ri^TglRk||, + ||Ri^TgmRj*J*Rj^TglRk*J-Ri^TglRk| and ||Ri^TgmRj*Rj^TglRk-J*Ri^TglRk*J|| + where gm,gl,gr are the varipus gorup members of Dn and J=diag([1,1-1]). + The correct "J-configuration" is given for the smallest of these 4 norms. + + :param Rijs: (n-choose-2)x3x3 array of relative rotations. + :return: List of n-choose-3 indices in {0,1,2,3} indicating + which J-configuration for each triplet of Rijs, inmik", Rik, Rjk_t) + + arr = np.zeros((8, 8, 3, 3), dtype=self.dtype) + arr[0:4, 0:4] = prod_arr - Rij[0] + arr[0:4, 4:8] = prod_arr - Rij[1] + arr[4:8, 0:4] = prod_arr - Rij[2] + arr[4:8, 4:8] = prod_arr - Rij[3] + + arr = arr.reshape((64, 9)) + arr = np.sum(arr**2, axis=1) + m = np.sort(arr.flatten()) + vote = np.sum(m[:16]) + + return vote + + def _J_sync_power_method(self, J_list): + """ + Calculate the leading eigenvector of the J-synchronization matrix + using the power method. + + As the J-synchronization matrix is of size (n-choose-2)x(n-choose-2), we + use the power method to compute the eigenvalues and eigenvectors, + while constructing the matrix on-the-fly. + + :param Rijs: (n-choose-2)x3x3 array of estimates of relative orientation matrices. + + :return: An array of length n-choose-2 consisting of 1 or -1, where the sign of the + i'th entry indicates whether the i'th relative orientation matrix will be J-conjugated. + """ + + # Set power method tolerance and maximum iterations. + epsilon = self.epsilon + max_iters = 100 + + # Initialize candidate eigenvectors + n_Rijs = len(self.pairs) + vec = randn(n_Rijs, seed=self.seed) + vec = vec / norm(vec) + residual = 1 + itr = 0 + + # Power method iterations + logger.info( + "Initiating power method to estimate J-synchronization matrix eigenvector." + ) + while itr < max_iters and residual > epsilon: + itr += 1 + vec_new = self._signs_times_v2(J_list, vec) + vec_new = vec_new / norm(vec_new) + residual = norm(vec_new - vec) + vec = vec_new + logger.info( + f"Iteration {itr}, residual {round(residual, 5)} (target {epsilon})" + ) + + # We need only the signs of the eigenvector + J_sync = np.sign(vec) + + return J_sync + + def _signs_times_v2(self, J_list, vec): + """ + Multiplication of the J-synchronization matrix by a candidate eigenvector. + + The J-synchronization matrix is a matrix representation of the handedness graph, Gamma, whose set of + nodes consists of the estimates Rijs and whose set of edges consists of the undirected edges between + all triplets of estimates vij, vjk, and vik, where i Date: Fri, 12 Apr 2024 16:00:11 -0400 Subject: [PATCH 019/105] beginning of color sync --- src/aspire/abinitio/commonline_d2.py | 111 +++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 6 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 0c8a402889..8f6db5f855 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -67,6 +67,8 @@ def __init__( self.seed = seed self.epsilon = epsilon self._generate_gs() + self.triplets = all_triplets(self.n_img) + self.pairs, self.pairs_to_linear = all_pairs(self.n_img, return_map=True) def estimate_rotations(self): """ @@ -74,19 +76,23 @@ def estimate_rotations(self): :return: Array of rotation matrices, size n_imgx3x3. """ + # Pre-compute phase-shifted polar Fourier. self.compute_shifted_pf() + + # Generate lookup data self.generate_lookup_data() self.generate_scl_lookup_data() + + # Compute common-line scores. self.compute_scl_scores() + + # Compute common-lines and estimate relative rotations Rijs. self.compute_cl_scores() - # Handedness Synchronization + # Perform handedness synchronization. self.Rijs_sync = self._global_J_sync(self.Rijs_est) - np.save("Rijs_sync.npy", self.Rijs_sync) - np.save("Rijs_est.npy", self.Rijs_est) - import pdb - pdb.set_trace() + # Synchronize colors. def compute_shifted_pf(self): """ @@ -714,7 +720,7 @@ def _global_J_sync(self, Rijs): of which might contain a spurious J (ie. vij = J*vi*vj^T*J instead of vij = vi*vj^T), we return Rijs and viis that all have either a spurious J or not. - :param Rijs: An (n-choose-2)x3x3 array where each 3x3 slice holds an estimate for the corresponding + :param Rijs: An (n-choose-2)x4 x3x3 array where each 3x3 slice holds an estimate for the corresponding outer-product vi*vj^T between the third rows of the rotation matrices Ri and Rj. Each estimate might have a spurious J independently of other estimates. @@ -898,6 +904,99 @@ def _signs_times_v2(self, J_list, vec): return new_vec + ###################### + # Synchronize Colors # + ###################### + + def _sync_colors(self, Rijs): + + # Generate array of one rank matrices from which we can extract rows. + # Matrices are of the form 0.5(Ri^TRj+Ri^TgkRj). Each such matrix can be + # written in the form Qi^T*Ik*Qj where Ik is a 3x3 matrix with all zero + # entries except for the entry a_kk, k in {1,2,3}. + n_pairs = len(Rijs) + Rijs_rows = np.zeros((n_pairs, 3, 3, 3), dtype=self.dtype) + for layer in range(3): + Rijs_rows[:, layer] = 0.5 * (Rijs[:, 0] + Rijs[:, layer + 1]) + + # Partition the set of matrices Rijs_rows into 3 sets of matrices, where + # each set there are only matrices Qi^T*Ik*Qj for a unique value of k in + # {1,2,3}. + # First determine for each pair of tuples of the form {Qi^T*Ik*Qj} and + # {Qr^T*Il*Qj} where {i,j}\cap{r,l}==1, whether l==r. + color_perms = self._match_colors(Rijs_rows) + return color_perms + + def _match_colors(self, Rijs_rows): + Rijs_rows_t = np.transpose(Rijs_rows, (0, 1, 3, 2)) + trip_perms = [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 0, 1], [2, 1, 0]] + inverse_perms = [ + [1, 2, 3], + [1, 3, 2], + [2, 1, 3], + [3, 1, 2], + [2, 3, 1], + [3, 2, 1], + ] + + m = np.zeros(6) + colors_i = np.zeros((len(self.triplets), 3), dtype=self.dtype) # int? + n_trip = len(self.triplets) + votes = np.zeros((n_trip)) + + # Compute relative color permutations. See Section 7.2 of paper. + for i, j, k in self.triplets: + ij = self.pairs_to_linear[i, j] + jk = self.pairs_to_linear[j, k] + ik = self.pairs_to_linear[i, k] + + # For r=1:3 compute 3*3 products v_{ji}(r)v_{ik}v_{kj} + prod_arr = np.einsum("nij,mjk->mnik", Rijs_rows[ik], Rijs_rows_t[jk]) + prod_arr_tmp = prod_arr.copy() + prod_arr = np.einsum( + "nij,mjk->nmik", Rijs_rows_t[ij], prod_arr.reshape((9, 3, 3)) + ) + prod_arr = np.transpose( + prod_arr.reshape((3, 3, 3, 9), order="F"), (2, 1, 0, 3) + ) + + # Compare to v_{jj}(r)=v_{ji}v_{ij}. + self_prods = Rijs_rows_t[ij] @ Rijs_rows[ij] + self_prods = self_prods.reshape(3, 9) + + prod_arr1 = prod_arr.copy() + prod_arr1[:, :, 0, :] = prod_arr1[:, :, 0, :] - self_prods[0] + prod_arr1[:, :, 1, :] = prod_arr1[:, :, 1, :] - self_prods[1] + prod_arr1[:, :, 2, :] = prod_arr1[:, :, 2, :] - self_prods[2] + norms1 = np.sum(prod_arr1**2, axis=3) + + prod_arr2 = prod_arr.copy() + prod_arr2[:, :, 0, :] = prod_arr2[:, :, 0, :] + self_prods[0] + prod_arr2[:, :, 1, :] = prod_arr2[:, :, 1, :] + self_prods[1] + prod_arr2[:, :, 2, :] = prod_arr2[:, :, 2, :] + self_prods[2] + norms2 = np.sum(prod_arr2**2, axis=3) + + # Compare to v_{jj}(r)=v_{jk}v_{kj}. + self_prods = Rijs_rows[jk] @ Rijs_rows_t[jk] + self_prods = self_prods.reshape(3, 9) + + prod_arr1 = prod_arr.copy() + prod_arr1[0] = prod_arr1[0] - self_prods[0] + prod_arr1[1] = prod_arr1[1] - self_prods[1] + prod_arr1[2] = prod_arr1[2] - self_prods[2] + norms1 = norms1 + np.sum(prod_arr1**2, axis=3) + + prod_arr2 = prod_arr.copy() + prod_arr2[0] = prod_arr2[0] + self_prods[0] + prod_arr2[1] = prod_arr2[1] + self_prods[1] + prod_arr2[2] = prod_arr2[2] + self_prods[2] + norms2 = norms2 + np.sum(prod_arr2**2, axis=3) + # Verfied up to this point! + + #################### + # Helper Functions # + #################### + @staticmethod def circ_seq(n1, n2, L): """ From 1c5593809765564bd37774238d014bcb9ccef443 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 17 Apr 2024 13:08:35 -0400 Subject: [PATCH 020/105] _match_colors function. --- src/aspire/abinitio/commonline_d2.py | 91 ++++++++++++++++++++++++---- 1 file changed, 79 insertions(+), 12 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 8f6db5f855..382b3eeb7e 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -929,20 +929,27 @@ def _sync_colors(self, Rijs): def _match_colors(self, Rijs_rows): Rijs_rows_t = np.transpose(Rijs_rows, (0, 1, 3, 2)) - trip_perms = [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 0, 1], [2, 1, 0]] - inverse_perms = [ - [1, 2, 3], - [1, 3, 2], - [2, 1, 3], - [3, 1, 2], - [2, 3, 1], - [3, 2, 1], - ] + trip_perms = np.array( + [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 0, 1], [2, 1, 0]], + dtype="int", + ) + inverse_perms = np.array( + [ + [0, 1, 2], + [0, 2, 1], + [1, 0, 2], + [2, 0, 1], + [1, 2, 0], + [2, 1, 0], + ], + dtype="int", + ) - m = np.zeros(6) - colors_i = np.zeros((len(self.triplets), 3), dtype=self.dtype) # int? + m = np.zeros((6, 6), dtype=self.dtype) + colors_i = np.zeros((len(self.triplets), 3), dtype=self.dtype) # ints? n_trip = len(self.triplets) votes = np.zeros((n_trip)) + trip_idx = 0 # Compute relative color permutations. See Section 7.2 of paper. for i, j, k in self.triplets: @@ -991,7 +998,67 @@ def _match_colors(self, Rijs_rows): prod_arr2[1] = prod_arr2[1] + self_prods[1] prod_arr2[2] = prod_arr2[2] + self_prods[2] norms2 = norms2 + np.sum(prod_arr2**2, axis=3) - # Verfied up to this point! + + # For r=1:3 compute 3*3 products v_{ij}(r)v_{jk}v_{ki} and compare to + # Compare to v_{ii}(r)=v_{ij}v_{ji} + prod_arr = np.transpose(prod_arr_tmp, (0, 1, 3, 2)) + prod_arr = np.einsum( + "nij,mjk->mnik", Rijs_rows[ij], prod_arr.reshape(9, 3, 3) + ) + prod_arr = np.transpose( + prod_arr.reshape((3, 3, 3, 9), order="F"), (1, 0, 2, 3) + ) + # Commented out calculations in matlab here. + + # Compare to v_{ii}(r)=v_{ik}v_{ki}. + self_prods = Rijs_rows[ik] @ Rijs_rows_t[ik] + self_prods = self_prods.reshape(3, 9) + + prod_arr1 = prod_arr.copy() + prod_arr1[:, 0] = prod_arr1[:, 0] - self_prods[0] + prod_arr1[:, 1] = prod_arr1[:, 1] - self_prods[1] + prod_arr1[:, 2] = prod_arr1[:, 2] - self_prods[2] + norms1 = norms1 + np.sum(prod_arr1**2, axis=3) + + prod_arr2 = prod_arr.copy() + prod_arr2[:, 0] = prod_arr2[:, 0] + self_prods[0] + prod_arr2[:, 1] = prod_arr2[:, 1] + self_prods[1] + prod_arr2[:, 2] = prod_arr2[:, 2] + self_prods[2] + norms2 = norms2 + np.sum(prod_arr2**2, axis=3) + + norms = np.minimum(norms1, norms2) + + for l in range(6): + p1 = trip_perms[l] + for r in range(6): + p2 = trip_perms[r] + m[l, r] = ( + norms[p2[0], p1[0], 0] + + norms[p2[1], p1[1], 1] + + norms[p2[2], p1[2], 2] + ) + + min_idx = np.unravel_index(np.argmin(m), m.shape) + votes[trip_idx] = m[min_idx] + colors_i[trip_idx, :2] = [ + 100 * (min_idx[0] + 1), + 10 * (min_idx[1] + 1), + ] # What's up with 100 and 10?? + # might need to use 1-based indexing for colors_i, ie min_idx[i] + 1 + + # Calculate the relative permutation of Rik to Rij given + # by (sigma_ik)\circ(sigma_ij)^-1 + inv_jk_perm = inverse_perms[min_idx[1]] + rel_perm = trip_perms[min_idx[0]] + rel_perm = rel_perm[inv_jk_perm] + colors_i[trip_idx, 2] = (2 * (rel_perm[0] + 1) - 1) + ( + rel_perm[1] > rel_perm[2] + ) + trip_idx += 1 + + colors_i = np.sum(colors_i, axis=1) + + return colors_i #################### # Helper Functions # From b879d445457302a586e902b3bce98278fa14e788 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 17 Apr 2024 15:50:00 -0400 Subject: [PATCH 021/105] mult_cmat_by_vec function --- src/aspire/abinitio/commonline_d2.py | 95 +++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 382b3eeb7e..eaab08b9cf 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1,6 +1,7 @@ import logging import numpy as np +import scipy.sparse.linalg as la from numpy.linalg import norm from aspire.abinitio import CLOrient3D @@ -925,6 +926,21 @@ def _sync_colors(self, Rijs): # First determine for each pair of tuples of the form {Qi^T*Ik*Qj} and # {Qr^T*Il*Qj} where {i,j}\cap{r,l}==1, whether l==r. color_perms = self._match_colors(Rijs_rows) + + # Compute eigenvectors of color matrix. This is just a matrix of dimensions + # 3(N choose 2)x3(N choose 2) where each entry corresponds to a pair of + # matrices {Qi^T*Ir*Qj} and {Qr^T*Il*Qj} and eqauls \delta_rl. + # The 2 leading eigenvectors span a linear subspace which contains a + # vector which encodes the partition. All the entries of the vector are + # either 1,0 or -1, where the number encodes which the index r in Ir. + # This vector is a linear combination of the two leading eigen vectors, + # and so we 'unmix' these vectors to retrieve it. + cmat = lambda v: self.mult_cmat_by_vec(color_perms, v) + omega = la.LinearOperator((3 * n_pairs,) * 2, cmat) + vals, colors = la.eigs(omega, k=3, which="LR") + import pdb + + pdb.set_trace() return color_perms def _match_colors(self, Rijs_rows): @@ -1040,11 +1056,12 @@ def _match_colors(self, Rijs_rows): min_idx = np.unravel_index(np.argmin(m), m.shape) votes[trip_idx] = m[min_idx] + + # Store permutation indices as digits in of base 10 number. colors_i[trip_idx, :2] = [ 100 * (min_idx[0] + 1), 10 * (min_idx[1] + 1), - ] # What's up with 100 and 10?? - # might need to use 1-based indexing for colors_i, ie min_idx[i] + 1 + ] # Calculate the relative permutation of Rik to Rij given # by (sigma_ik)\circ(sigma_ij)^-1 @@ -1060,6 +1077,80 @@ def _match_colors(self, Rijs_rows): return colors_i + def mult_cmat_by_vec(self, c_perms, v): + """ + Multiply color matrix by vector v "on the fly". + + :param c_perms: An (N over 3) vector. Each corresponds to a triplet of + indices i Date: Thu, 18 Apr 2024 12:47:50 -0400 Subject: [PATCH 022/105] fix loop indices and lambda function. --- src/aspire/abinitio/commonline_d2.py | 30 +++++++++++++--------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index eaab08b9cf..9cddddd466 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -727,8 +727,6 @@ def _global_J_sync(self, Rijs): :return: Rijs, all of which have a spurious J or not. """ - n_img = self.n_img - # Find best J_configuration. J_list = self._J_configuration(Rijs) @@ -935,12 +933,11 @@ def _sync_colors(self, Rijs): # either 1,0 or -1, where the number encodes which the index r in Ir. # This vector is a linear combination of the two leading eigen vectors, # and so we 'unmix' these vectors to retrieve it. - cmat = lambda v: self.mult_cmat_by_vec(color_perms, v) - omega = la.LinearOperator((3 * n_pairs,) * 2, cmat) - vals, colors = la.eigs(omega, k=3, which="LR") - import pdb + color_mat = la.LinearOperator( + (3 * n_pairs,) * 2, lambda v: self.mult_cmat_by_vec(color_perms, v) + ) + vals, colors = la.eigs(color_mat, k=3, which="LR") - pdb.set_trace() return color_perms def _match_colors(self, Rijs_rows): @@ -1044,11 +1041,11 @@ def _match_colors(self, Rijs_rows): norms = np.minimum(norms1, norms2) - for l in range(6): - p1 = trip_perms[l] - for r in range(6): - p2 = trip_perms[r] - m[l, r] = ( + for r in range(6): + p1 = trip_perms[r] + for s in range(6): + p2 = trip_perms[s] + m[r, s] = ( norms[p2[0], p1[0], 0] + norms[p2[1], p1[1], 1] + norms[p2[2], p1[2], 2] @@ -1100,18 +1097,19 @@ def mult_cmat_by_vec(self, c_perms, v): trip_idx = 0 for i in trange(self.n_img, desc="Computing cmat_times_v."): for j in range(i + 1, self.n_img - 1): - ij = 3 * self.pairs_to_linear[i, j] - 2 + ij = 3 * self.pairs_to_linear[i, j] for k in range(j + 1, self.n_img): - ik = 3 * self.pairs_to_linear[i, k] - 2 - jk = 3 * self.pairs_to_linear[j, k] - 2 + ik = 3 * self.pairs_to_linear[i, k] + jk = 3 * self.pairs_to_linear[j, k] # Extract permutation indices from c_perms n = c_perms[trip_idx] + trip_idx += 1 p_n1 = n // 100 p_n3 = n % 10 p_n2 = (n - p_n1 * 100 - p_n3) // 10 - # Adjust for 0-based indexing. (Take this out) + # Adjust for 0-based indexing. (Take this out by computing c_perms with 0-base) p_n1 = (p_n1 - 1).astype("int") p_n2 = (p_n2 - 1).astype("int") p_n3 = (p_n3 - 1).astype("int") From 902dd205bb76a14c70b7b25fc7e5f4a48cd482e9 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 19 Apr 2024 11:59:41 -0400 Subject: [PATCH 023/105] Add _unmix_colors function. --- src/aspire/abinitio/commonline_d2.py | 87 +++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 3 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 9cddddd466..09c0ecb8da 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -908,7 +908,9 @@ def _signs_times_v2(self, J_list, vec): ###################### def _sync_colors(self, Rijs): - + """ + Add documention! + """ # Generate array of one rank matrices from which we can extract rows. # Matrices are of the form 0.5(Ri^TRj+Ri^TgkRj). Each such matrix can be # written in the form Qi^T*Ik*Qj where Ik is a 3x3 matrix with all zero @@ -937,8 +939,12 @@ def _sync_colors(self, Rijs): (3 * n_pairs,) * 2, lambda v: self.mult_cmat_by_vec(color_perms, v) ) vals, colors = la.eigs(color_mat, k=3, which="LR") + vals = np.real(vals) + colors = np.real(colors) + + cp, _ = self._unmix_colors(colors[:, :2]) - return color_perms + return cp def _match_colors(self, Rijs_rows): Rijs_rows_t = np.transpose(Rijs_rows, (0, 1, 3, 2)) @@ -1095,7 +1101,7 @@ def mult_cmat_by_vec(self, c_perms, v): ) out = np.zeros_like(v) trip_idx = 0 - for i in trange(self.n_img, desc="Computing cmat_times_v."): + for i in range(self.n_img): for j in range(i + 1, self.n_img - 1): ij = 3 * self.pairs_to_linear[i, j] for k in range(j + 1, self.n_img): @@ -1149,6 +1155,81 @@ def mult_cmat_by_vec(self, c_perms, v): out[jk + 2] = out[jk + 2] - v[p[0]] - v[p[1]] + v[p[2]] return out + def _unmix_colors(self, color_vecs): + """ + The 'color vector' which partitions the rank 1 3x3 matrices into 3 sets + is one of 2 leading orthogonal eigenvectors of the color matrix. + SVD retrieves two orthogonal linear combinations of these vectors which + can be 'unmixed' to retrieve the color vector by finding a suitable + 2D rotation of these vectors (see Section 7.3 of D2 paper for details). + """ + n_p = color_vecs.shape[0] // 3 + d_theta = 360 // self.n_theta + max_t = 360 // d_theta + 1 + + def R_theta(theta): + R = np.array( + [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]], + dtype=self.dtype, + ) + return R + + s = float("inf") + scores = np.zeros(max_t, dtype=self.dtype) + idx = 0 + for t in np.arange(0, max_t, 0.5): + unmix_ev = color_vecs @ R_theta(np.pi * t / 180) + s1 = unmix_ev[:, 0].reshape(n_p, 3) + p11 = (-s1).argsort(axis=1) # descending argsort + s1 = np.take_along_axis(s1, p11, axis=1) + score11 = np.sum((s1[:, 0] + s1[:, 2]) ** 2 + s1[:, 1] ** 2) + + s2 = abs(unmix_ev[:, 1].reshape(n_p, 3)) + p12 = (-s2).argsort(axis=1) # descending argsort + s2 = np.take_along_axis(s2, p12, axis=1) + score12 = np.sum( + (s2[:, 0] - 2 * s2[:, 1]) ** 2 + + (s2[:, 0] - 2 * s2[:, 2]) ** 2 + + (s2[:, 1] - s2[:, 2]) ** 2 + ) # Matlab comment: Is this an error??? + instead of - in the first 2 members + + s1 = abs(unmix_ev[:, 0].reshape(n_p, 3)) + p12 = (-s1).argsort(axis=1) # descending argsort + s1 = np.take_along_axis(s1, p11, axis=1) + score22 = np.sum( + (s1[:, 0] - 2 * s1[:, 1]) ** 2 + + (s1[:, 0] - 2 * s1[:, 2]) ** 2 + + (s1[:, 1] - s1[:, 2]) ** 2 + ) + + s2 = unmix_ev[:, 1].reshape(n_p, 3) + p22 = (-s2).argsort(axis=1) # descending argsort + s2 = np.take_along_axis(s2, p12, axis=1) + score21 = np.sum((s2[:, 0] + s2[:, 2]) ** 2 + s2[:, 1] ** 2) + + score_vecs = [score11 + score12, score21 + score22] + which_vec = np.argmin([score11 + score12, score21 + score22]) + scores[idx] = score_vecs[which_vec] + if scores[idx] < s: + s = scores[idx] + if which_vec == 0: + p = p11 + else: + p = p22 + best_unmix = unmix_ev[:, which_vec] + + # Assign integers between 1:3 to permutations + colors = np.zeros((n_p, 3), dtype=int) + for i in range(n_p): + p_i = p[i] + p_i_sqr = p_i[p_i] + if np.sum((p_i_sqr - [0, 1, 2]) ** 2) == 0: # non-cyclic permutation + colors[i] = p_i + else: + colors[i] = p_i_sqr + + return colors.flatten(), best_unmix + #################### # Helper Functions # #################### From 7ab209afff6174d54afa2bc666161e1cce53ba6a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 19 Apr 2024 15:56:29 -0400 Subject: [PATCH 024/105] Add partial sync_signs. --- src/aspire/abinitio/commonline_d2.py | 86 +++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 09c0ecb8da..f831817e41 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -94,6 +94,10 @@ def estimate_rotations(self): self.Rijs_sync = self._global_J_sync(self.Rijs_est) # Synchronize colors. + self.colors, self.Rijs_rows = self._sync_colors(self.Rijs_sync) + + # Synchronize signs. + Ris = self._sync_signs(self.Rijs_rows, self.colors) def compute_shifted_pf(self): """ @@ -944,7 +948,7 @@ def _sync_colors(self, Rijs): cp, _ = self._unmix_colors(colors[:, :2]) - return cp + return cp, Rijs_rows def _match_colors(self, Rijs_rows): Rijs_rows_t = np.transpose(Rijs_rows, (0, 1, 3, 2)) @@ -1230,6 +1234,86 @@ def R_theta(theta): return colors.flatten(), best_unmix + ##################### + # Synchronize Signs # + ##################### + + def _sync_signs(self, rr, c_vec): + """ + This function executes the final stage of the algorithm, Signs + synchroniztion. At the end all rows of the rotations Ri are exctracted + and the matrices Ri are assembled. + """ + # Partition the union of tuples {0.5*(Ri^TRj+Ri^TgkRj), k=1:3} according + # to the color partition established in color synchroniztion procedure. + # The partition is stored in two different arrays each with the purpose + # of a computational speed up for two different computations performed + # later (space considerations are of little concern since arrays are ~ + # o(N^2) which doesn't pose a constraint for inputs on the scale of 10^3-10^4. + n_pairs = len(self.pairs) + c_mat_5d = np.zeros((self.n_img, self.n_img, 3, 3, 3), dtype=self.dtype) + c_mat_4d = np.zeros((n_pairs, 3, 3, 3), dtype=self.dtype) + for i in trange(self.n_img - 1): + for j in range(i + 1, self.n_img): + ij = self.pairs_to_linear[i, j] + c_mat_5d[i, j, c_vec[3 * ij]] = rr[ij, 0] + c_mat_5d[i, j, c_vec[3 * ij + 1]] = rr[ij, 1] + c_mat_5d[i, j, c_vec[3 * ij + 2]] = rr[ij, 2] + c_mat_5d[j, i, c_vec[3 * ij]] = rr[ij, 0].T + c_mat_5d[j, i, c_vec[3 * ij + 1]] = rr[ij, 1].T + c_mat_5d[j, i, c_vec[3 * ij + 2]] = rr[ij, 2].T + + c_mat_4d[ij, c_vec[3 * ij]] = rr[ij, 0] + c_mat_4d[ij, c_vec[3 * ij + 1]] = rr[ij, 1] + c_mat_4d[ij, c_vec[3 * ij + 2]] = rr[ij, 2] + + # Compute estimates for the tuples {0.5*(Ri^TRi+Ri^TgkRi), k=1:3} for + # i=1:N. For 1<=i,j<=N and c=1,2,3 write Qij^c=0.5*(Ri^TRj+Ri^TgmRj). + # For each i in {1:N} and each k in {1,2,3} the estimator is the + # average over all j~=i of Qij^c*(Qij^c)^T. + # Since in practice the result of the average is not really rank 1, we + # compute the best rank approximation to this average. + for i in range(self.n_img): + for c in range(3): + Rijs = c_mat_5d[i, :, c] + Rijs = np.delete(Rijs, i, axis=0) + Rii_est = Rijs @ np.transpose(Rijs, (0, 2, 1)) + Rii = np.mean(Rii_est, axis=0) + U, _, _ = np.linalg.svd(Rii) + c_mat_5d[i, i, c] = np.outer(U[:, 0], U[:, 0]) + + # Construct the 3Nx3N row synchroniztion matrices (as done for C_2), one + # for all first rows of the matrices Ri, one for all second rows and one + # for all third rows. The ij'th block of the k'th matrix is Qij^c. + # In C_2 one such matrix is constructed for the 3rd rows + # and is rank 1 by construction. In practice, thus far, for each c and + # (i,j) we either have Qij^c or -Qij^c independently. + c_mat = np.zeros((3, 3 * self.n_img, 3 * self.n_img), dtype=self.dtype) + rot = np.zeros((self.n_img, 3, 3), dtype=self.dtype) + for i in range(self.n_img - 1): + for j in range(i + 1, self.n_img): + ij = self.pairs_to_linear[i, j] + c_mat[c_vec[3 * ij], 3 * i : 3 * i + 2, 3 * j : 3 * j + 2] = rr[ij, 0] + c_mat[c_vec[3 * ij + 1], 3 * i : 3 * i + 2, 3 * j : 3 * j + 2] = rr[ + ij, 1 + ] + c_mat[c_vec[3 * ij + 2], 3 * i : 3 * i + 2, 3 * j : 3 * j + 2] = rr[ + ij, 2 + ] + + c_mat[0] = c_mat[0] + c_mat[0].T + c_mat[1] = c_mat[1] + c_mat[1].T + c_mat[2] = c_mat[2] + c_mat[2].T + + for c in range(3): + for i in range(self.n_img): + c_mat[c, 3 * i : 3 * i + 2, 3 * i : 3 * i + 2] = c_mat_5d[i, i, c] + + # To decompose cMat as a rank 1 matrix we need to adjust the signs of the + # Qij^c so that sign(Qij^c*Qjk^c) = sign(Qik^c) for all c=1,2,3 and (i,j). + # In practice we compare the sign of the sum of the entries of Qij^c*Qjk^c + # to the sum of entries of Qik^c. + #################### # Helper Functions # #################### From aeae468f8a8f9ac1295216063771e05a159d271f Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 25 Apr 2024 15:26:28 -0400 Subject: [PATCH 025/105] syncSigns method. --- src/aspire/abinitio/commonline_d2.py | 225 ++++++++++++++++++++++++++- 1 file changed, 220 insertions(+), 5 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index f831817e41..e8d3bde421 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -99,6 +99,8 @@ def estimate_rotations(self): # Synchronize signs. Ris = self._sync_signs(self.Rijs_rows, self.colors) + self.rotations = Ris + def compute_shifted_pf(self): """ Pre-compute shifted and full polar Fourier transforms. @@ -945,7 +947,7 @@ def _sync_colors(self, Rijs): vals, colors = la.eigs(color_mat, k=3, which="LR") vals = np.real(vals) colors = np.real(colors) - + colors[:, 1] = -colors[:, 1] # TODO: Take this out. Only here for debugging. cp, _ = self._unmix_colors(colors[:, :2]) return cp, Rijs_rows @@ -1293,11 +1295,11 @@ def _sync_signs(self, rr, c_vec): for i in range(self.n_img - 1): for j in range(i + 1, self.n_img): ij = self.pairs_to_linear[i, j] - c_mat[c_vec[3 * ij], 3 * i : 3 * i + 2, 3 * j : 3 * j + 2] = rr[ij, 0] - c_mat[c_vec[3 * ij + 1], 3 * i : 3 * i + 2, 3 * j : 3 * j + 2] = rr[ + c_mat[c_vec[3 * ij], 3 * i : 3 * i + 3, 3 * j : 3 * j + 3] = rr[ij, 0] + c_mat[c_vec[3 * ij + 1], 3 * i : 3 * i + 3, 3 * j : 3 * j + 3] = rr[ ij, 1 ] - c_mat[c_vec[3 * ij + 2], 3 * i : 3 * i + 2, 3 * j : 3 * j + 2] = rr[ + c_mat[c_vec[3 * ij + 2], 3 * i : 3 * i + 3, 3 * j : 3 * j + 3] = rr[ ij, 2 ] @@ -1307,13 +1309,226 @@ def _sync_signs(self, rr, c_vec): for c in range(3): for i in range(self.n_img): - c_mat[c, 3 * i : 3 * i + 2, 3 * i : 3 * i + 2] = c_mat_5d[i, i, c] + c_mat[c, 3 * i : 3 * i + 3, 3 * i : 3 * i + 3] = c_mat_5d[i, i, c] # To decompose cMat as a rank 1 matrix we need to adjust the signs of the # Qij^c so that sign(Qij^c*Qjk^c) = sign(Qik^c) for all c=1,2,3 and (i,j). # In practice we compare the sign of the sum of the entries of Qij^c*Qjk^c # to the sum of entries of Qik^c. + # For computational comfort the signs for each c=1,2,3 are stored in a + # Nx(N over 2) array, where the ij'th column corresponds to the signs of + # Qij^c * Qjk^c for k~=i,j. The entries in the k=i,j rows of the ij'th + # column are zero, the value zero is arbitrary, since these entries are + # not used by the algorithm, and only exist for comfort (of storage and + # access). + signs = np.zeros((3, n_pairs, self.n_img), dtype=self.dtype) + for c in range(3): + for p in range(n_pairs): + i, j = self.pairs[p] + idx_mask = np.full(self.n_img, True) + idx_mask[[i, j]] = False + signs[c, p, idx_mask] = self.calc_Rij_prods(c_mat_5d, i, j, c) + + # Now compute the signs of Qij^c. + est_signs = np.sign(np.sum(c_mat_4d, axis=(-2, -1))) + signs = np.transpose(signs, (0, 2, 1)) + for c in range(3): + signs[c] = est_signs[:, c] * signs[c] + + # Qik^c can be compared with Qir^c*Qrk^c for each r~=i,k, that is, + # N-2 options. Another way to look at this, is that the r'th image + # participates in all comparisons of the form sign(Qir^c*Qrk^c)~sign(Qik) + # for r~=i,k for each c=1,2,3 (see Section 8 in D2 paper). + # For each image r construct a 3Nx3N matrix. If + # sign(Qir^c*Qrk^c)~sign(Qik)=1, its ik'th 3x3 block is set to Qik, + # otherwise, it is set to -Qik. + sync_signs2 = np.arange(self.n_img).reshape((1, 1, self.n_img, 1)) + sync_signs2 = np.tile(sync_signs2, (3, self.n_img, 1, self.n_img)) + for c in range(3): + for r in range(self.n_img): + # Fill signs for synchroniztion for the r'th image. + # Go over all i,j~=r. + i_idx = np.concatenate( + (np.arange(0, r), np.arange(r + 1, self.n_img)) + ) # i~=r + for i in i_idx: + if i <= r: + j_idx = np.concatenate( + (np.arange(i + 1, r), np.arange(r + 1, self.n_img)) + ) + else: + j_idx = np.arange(i + 1, self.n_img) + for j in j_idx: + ij = self.pairs_to_linear[i, j] + sync_signs2[c, r, j, i] = ( + j + 0.5 * (1 - signs[c, r, ij]) * self.n_img + ) + sync_signs2[c, r, i, j] = ( + i + 0.5 * (1 - signs[c, r, ij]) * self.n_img + ) + # The function (1-x)/2 maps 1->0 and -1->1 + + c_mat_5d_mp = np.concatenate((c_mat_5d, -c_mat_5d), axis=1) + rows_arr = np.zeros((3, self.n_img, 3 * self.n_img), dtype=self.dtype) + svals = np.zeros((3, 2, self.n_img), dtype=self.dtype) + + logger.info("Constructing and decomposing N sign synchronization matrices...") + for c in range(3): + for r in range(self.n_img): + # Image r used for signs. + c_mat_eff = self.fill_sign_sync_matrix_c(c_mat_5d_mp, sync_signs2, c, r) + + # Construct (3*N)x(3*N) rank 1 matrices from Qik + c_mat_for_svd = np.zeros( + (3 * self.n_img, 3 * self.n_img), dtype=self.dtype + ) + for i in range(self.n_img): + row_3Nx3 = c_mat_eff[i] + row_3Nx3 = row_3Nx3.reshape(3 * self.n_img, 3) + c_mat_for_svd[:, 3 * i : 3 * i + 3] = row_3Nx3 + + c_mat_for_svd = c_mat_for_svd + c_mat_for_svd.T + + # Extract leading eigenvector of rank 1 matrix. For each r and c + # this gives an estimate for the c'th row of the rotation Rr, up + # to sign +/-. + for i in range(self.n_img): + c_mat_for_svd[3 * i : 3 * i + 3, 3 * i : 3 * i + 3] = c_mat_eff[ + i, i + ] + U, S, _ = np.linalg.svd(c_mat_for_svd) + svals[c, :, r] = S[:2] + rows_arr[c, r] = U[:, 0] + + # Sync signs according to results for each image. Dot products between + # signed row estimates are used to construct an (N over 2)x(N over 2) + # sign synchronization matrix S. If (v_i)k and (v_j)k are the i'th and + # j'th estimates for the c'th row of Rk, then the entry (i,k),(k,j) entry + # of S is <(v_i)k,(v_j)k>, where the rows and columns of S are indexed by + # double indexes (i,j), 1<=i 0: + ij_signs[zeros_idx] = 1 + + return np.sign(ij_signs) + + def mult_smat_by_vec(self, v, sign_mat, pairs_map): + """ + Multiplies the signs sync matrix by a vector. + """ + v_out = np.zeros_like(v) + for i in range(self.n_img): + for j in range(i + 1, self.n_img): + ij = self.pairs_to_linear[i, j] + v_out[ij] = sign_mat[ij] @ v[pairs_map[ij]] + return v_out + #################### # Helper Functions # #################### From 7917332fbe2e76b338d5320e8083ccbb1a061de4 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 1 May 2024 15:37:44 -0400 Subject: [PATCH 026/105] max_shift_1d, mask parameter. --- src/aspire/abinitio/commonline_d2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index e8d3bde421..5c2bc6f4ba 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -35,6 +35,7 @@ def __init__( eq_min_dist=7, epsilon=0.01, seed=None, + mask=True, ): """ Initialize object for estimating 3D orientations for molecules with D2 symmetry. @@ -59,6 +60,7 @@ def __init__( n_theta=n_theta, max_shift=max_shift, shift_step=shift_step, + mask=mask, ) self.grid_res = grid_res @@ -109,8 +111,9 @@ def compute_shifted_pf(self): # Generate shift phases. r_max = pf.shape[-1] + max_shift_1d = np.ceil(2 * np.sqrt(2) * self.max_shift) shifts, shift_phases, _ = self._generate_shift_phase_and_filter( - r_max, self.max_shift, self.shift_step + r_max, max_shift_1d, self.shift_step ) self.n_shifts = len(shifts) From 8fcb8ee624a0746453da6da8a0e9ad7369b20169 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 8 May 2024 15:29:52 -0400 Subject: [PATCH 027/105] Fix bugs. Output rotations matching matlab sometimes. Need to stabilize all eigs and svds. --- src/aspire/abinitio/commonline_d2.py | 39 ++++++++++++++++++---------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 5c2bc6f4ba..7f9d53d097 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -118,7 +118,12 @@ def compute_shifted_pf(self): self.n_shifts = len(shifts) # Reconstruct full polar Fourier for use in correlation. + pf[:, :, 0] = 0 # Matching matlab convention to zero out the lowest frequency. pf /= norm(pf, axis=2)[..., np.newaxis] # Normalize each ray. + pf *= ( + np.sqrt(2) / 2 + ) # Magic number to match matlab pf. (root2 over 2) Remove after debug. + pf = pf[:, :, ::-1] # also to match matlab self.pf_full = PolarFT.half_to_full(pf) # Pre-compute shifted pf's. @@ -159,7 +164,7 @@ def compute_cl_scores(self): pf_j = self.pf_full[j] # Compute maximum correlation over all shifts. - corrs = np.real(pf_i @ np.conj(pf_j).T) + corrs = 2 * np.real(pf_i @ np.conj(pf_j).T) corrs = np.reshape( corrs, (self.n_shifts, self.n_theta // 2, self.n_theta) ) @@ -169,6 +174,7 @@ def compute_cl_scores(self): cl_idx = np.unravel_index( self.cl_idx, (self.n_theta // 2, self.n_theta) ) + prod_corrs = corrs[cl_idx] prod_corrs = prod_corrs.reshape(len(prod_corrs) // 4, 4) prod_corrs = np.prod(prod_corrs, axis=1) @@ -217,6 +223,7 @@ def get_Rijs_from_oct(self, lin_idx, octant=1): unique_pairs = self.eq2eq_Rij_table_11 else: unique_pairs = self.eq2eq_Rij_table_12 + n_theta = self.n_inplane_rots n_lookup_pairs = np.sum(unique_pairs, dtype=np.int64) n_rots = len(self.sphere_grid1) @@ -256,11 +263,11 @@ def get_Rijs_from_oct(self, lin_idx, octant=1): # Assemble relative rotations Ri^TgRj using linear indices, where g is a group member of D2. Ris_lin_idx = np.ravel_multi_index((est_idx[0], inplane_i), s[:2]) Rjs_lin_idx = np.ravel_multi_index((est_idx[1], inplane_j), s2[:2]) - Ris = np.transpose(inplane_rotated_grid[Ris_lin_idx], (0, 2, 1)) - Rjs = np.transpose(inplane_rotated_grid2[Rjs_lin_idx], (0, 2, 1)) + Ris_t = np.transpose(inplane_rotated_grid[Ris_lin_idx], (0, 2, 1)) + Rjs = inplane_rotated_grid2[Rjs_lin_idx] for k, g in enumerate(self.gs): - Rijs_est[:, k] = np.transpose(Ris, (0, 2, 1)) @ (g * Rjs) + Rijs_est[:, k] = Ris_t @ (g * Rjs) Rijs_est[transpose_idx] = np.transpose(Rijs_est[transpose_idx], (0, 1, 3, 2)) @@ -293,7 +300,7 @@ def compute_scl_scores(self): pf_i_shifted = self.pf_shifted[i] # Compute max correlation over all shifts. - corrs = np.real(pf_i_shifted @ np.conj(pf_full_i).T) + corrs = 2 * np.real(pf_i_shifted @ np.conj(pf_full_i).T) corrs = np.reshape(corrs, (self.n_shifts, n_theta // 2, n_theta)) corrs = np.max(corrs, axis=0) @@ -322,7 +329,7 @@ def compute_scl_scores(self): true_scls_corrs = corrs[scl_idx_list] scls_cand_idx = self.scl_idx_lists[1, eq_idx, j] eq_measures_j = eq_measures[scls_cand_idx] - measures_agg = true_scls_corrs * eq_measures_j + measures_agg = np.outer(true_scls_corrs, eq_measures_j) k = self.non_tv_eq_idx[eq_idx] corrs_out[i, k * n_inplane + j] = np.max(measures_agg) @@ -646,8 +653,9 @@ def generate_scl_indices(self, scl_angles, eq_class): idx2 = self.circ_seq(scl_angles[i, j, 0, 1], scl_angles[i, j, 1, 1], L) # Adjust so idx1 is in [0, 180) range. - idx1[idx1 >= 180] = (idx1[idx1 >= 180] - L // 2) % (L // 2) - idx2[idx1 >= 180] = (idx2[idx1 >= 180] + L // 2) % L + geq_180 = idx1 >= 180 + idx1[geq_180] = (idx1[geq_180] - L // 2) % (L // 2) + idx2[geq_180] = (idx2[geq_180] + L // 2) % L # register indices in list. eq_lin_idx_lists[0, count_eq, j] = np.ravel_multi_index( @@ -950,7 +958,7 @@ def _sync_colors(self, Rijs): vals, colors = la.eigs(color_mat, k=3, which="LR") vals = np.real(vals) colors = np.real(colors) - colors[:, 1] = -colors[:, 1] # TODO: Take this out. Only here for debugging. + colors = np.sign(colors[0]) * colors # Stable eigs cp, _ = self._unmix_colors(colors[:, :2]) return cp, Rijs_rows @@ -1066,10 +1074,12 @@ def _match_colors(self, Rijs_rows): + norms[p2[2], p1[2], 2] ) - min_idx = np.unravel_index(np.argmin(m), m.shape) + # In the event of duplicate min values min_idx is the first occurence + # by column order to match matlab outputs. + min_idx = np.unravel_index(np.argmin(m.T), m.shape)[::-1] votes[trip_idx] = m[min_idx] - # Store permutation indices as digits in of base 10 number. + # Store permutation indices as digits of a base 10 number. colors_i[trip_idx, :2] = [ 100 * (min_idx[0] + 1), 10 * (min_idx[1] + 1), @@ -1236,8 +1246,9 @@ def R_theta(theta): colors[i] = p_i else: colors[i] = p_i_sqr - - return colors.flatten(), best_unmix + colors = colors.flatten() + colors = 2 - colors # For debug. remove + return colors, best_unmix ##################### # Synchronize Signs # @@ -1258,7 +1269,7 @@ def _sync_signs(self, rr, c_vec): n_pairs = len(self.pairs) c_mat_5d = np.zeros((self.n_img, self.n_img, 3, 3, 3), dtype=self.dtype) c_mat_4d = np.zeros((n_pairs, 3, 3, 3), dtype=self.dtype) - for i in trange(self.n_img - 1): + for i in range(self.n_img - 1): for j in range(i + 1, self.n_img): ij = self.pairs_to_linear[i, j] c_mat_5d[i, j, c_vec[3 * ij]] = rr[ij, 0] From e094a9557f316646388f65683ff96d142ffd3538 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 9 May 2024 14:37:21 -0400 Subject: [PATCH 028/105] Stable J_sync and SVDs. --- src/aspire/abinitio/commonline_d2.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 7f9d53d097..daf7f31994 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -865,6 +865,7 @@ def _J_sync_power_method(self, J_list): # We need only the signs of the eigenvector J_sync = np.sign(vec) + J_sync = np.sign(J_sync[0]) * J_sync # Stabilize J_sync return J_sync @@ -1466,6 +1467,7 @@ def _sync_signs(self, rr, c_vec): rmatvec=lambda v, s=sign_mat: self.mult_smat_by_vec(v, s, pairs_map), ) U, S, _ = la.svds(smat, k=3, which="LM") + U = np.sign(U[0]) * U # Stable svds signs[c] = U[:, -1] # Returns in ascending order s_out[c] = S[::-1] @@ -1495,6 +1497,11 @@ def _sync_signs(self, rr, c_vec): svals2[1] = S2[::-1] svals2[2] = S3[::-1] + # Stable eigenvectors. + U1 = np.sign(U1[0]) * U1 + U2 = np.sign(U2[0]) * U2 + U3 = np.sign(U3[0]) * U3 + # The c'th row of the rotation Rj is Uc(3*j-2:3*j,1)/norm(Uc(3*j-2:3*j,1)), # (Rows must be normalized to length 1). logger.info("Assembeling rows to rotations matrices...") From 5d05f81b5d0b8530f463bb8d8ae9f92e3a568b09 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 10 May 2024 11:18:49 -0400 Subject: [PATCH 029/105] Add initial testing. --- tests/test_orient_d2.py | 180 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 tests/test_orient_d2.py diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py new file mode 100644 index 0000000000..0e7d5b19a5 --- /dev/null +++ b/tests/test_orient_d2.py @@ -0,0 +1,180 @@ +import numpy as np +import pytest + +from aspire.abinitio import CLSymmetryD2 +from aspire.source import Simulation +from aspire.utils import ( + J_conjugate, + Rotation, + all_pairs, + cyclic_rotations, + mean_aligned_angular_distance, + randn, + utest_tolerance, +) +from aspire.volume import DnSymmetricVolume, DnSymmetryGroup + +############## +# Parameters # +############## + +DTYPE = [np.float64, np.float32] +RESOLUTION = [48, 49] +N_IMG = [10] +OFFSETS = [0] +SEED = 42 + + +@pytest.fixture(params=DTYPE, ids=lambda x: f"dtype={x}") +def dtype(request): + return request.param + + +@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}") +def resolution(request): + return request.param + + +@pytest.fixture(params=N_IMG, ids=lambda x: f"n images={x}") +def n_img(request): + return request.param + + +@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}") +def offsets(request): + return request.param + + +############ +# Fixtures # +############ + + +@pytest.fixture +def source(n_img, resolution, dtype, offsets): + vol = DnSymmetricVolume( + L=resolution, order=2, C=1, K=100, dtype=dtype, seed=SEED + ).generate() + + src = Simulation( + n=n_img, + L=resolution, + vols=vol, + offsets=offsets, + amplitudes=1, + seed=SEED, + ) + + return src + + +@pytest.fixture +def orient_est(source): + orient_est = CLSymmetryD2( + source, + max_shift=0, + shift_step=1, + n_theta=360, + n_rad=source.L, + grid_res=350, # Tuned for speed + inplane_res=15, # Tuned for speed + eq_min_dist=10, # Tuned for speed + epsilon=0.01, + seed=SEED, + ) + + return orient_est + + +######### +# Tests # +######### + + +def test_estimate_rotations(orient_est): + # Estimate rotations. + orient_est.estimate_rotations() + rots_est = orient_est.rotations + + # Ground truth rotations. + rots_gt = orient_est.src.rotations + + # g-sync ground truth rotations. + rots_gt_sync = g_sync_d2(rots_est, rots_gt) + + # Register estimates to ground truth rotations and check that the + # mean angular distance between them is less than 5 degrees. + mean_aligned_angular_distance(rots_est, rots_gt_sync, degree_tol=5) + + +#################### +# Helper Functions # +#################### + + +def g_sync_d2(rots, rots_gt): + """ + Every estimated rotation might be a version of the ground truth rotation + rotated by g^{s_i}, where s_i = 0, 1, ..., order. This method synchronizes the + ground truth rotations so that only a single global rotation need be applied + to all estimates for error analysis. + + :param rots: Estimated rotation matrices + :param rots_gt: Ground truth rotation matrices. + + :return: g-synchronized ground truth rotations. + """ + assert len(rots) == len( + rots_gt + ), "Number of estimates not equal to number of references." + n_img = len(rots) + dtype = rots.dtype + + rots_symm = DnSymmetryGroup(2, dtype).matrices + order = len(rots_symm) + + A_g = np.zeros((n_img, n_img), dtype=complex) + + pairs = all_pairs(n_img) + + for i, j in pairs: + Ri = rots[i] + Rj = rots[j] + Rij = Ri.T @ Rj + + Ri_gt = rots_gt[i] + Rj_gt = rots_gt[j] + + diffs = np.zeros(order) + for s, g_s in enumerate(rots_symm): + Rij_gt = Ri_gt.T @ g_s @ Rj_gt + diffs[s] = min( + [ + np.linalg.norm(Rij - Rij_gt), + np.linalg.norm(Rij - J_conjugate(Rij_gt)), + ] + ) + + idx = np.argmin(diffs) + + A_g[i, j] = np.exp(-1j * 2 * np.pi / order * idx) + + # A_g(k,l) is exp(-j(-theta_k+theta_l)) + # Diagonal elements correspond to exp(-i*0) so put 1. + # This is important only for verification purposes that spectrum is (K,0,0,0...,0). + A_g += np.conj(A_g).T + np.eye(n_img) + + _, eig_vecs = np.linalg.eigh(A_g) + leading_eig_vec = eig_vecs[:, -1] + + angles = np.exp(1j * 2 * np.pi / order * np.arange(order)) + rots_gt_sync = np.zeros((n_img, 3, 3), dtype=dtype) + + for i, rot_gt in enumerate(rots_gt): + # Since the closest ccw or cw rotation are just as good, + # we take the absolute value of the angle differences. + angle_dists = np.abs(np.angle(leading_eig_vec[i] / angles)) + power_g_Ri = np.argmin(angle_dists) + rots_gt_sync[i] = rots_symm[power_g_Ri] @ rot_gt + + return rots_gt_sync From 2538f9a9c47b08a9a4addb20662048b1e47adae7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 10 May 2024 11:24:43 -0400 Subject: [PATCH 030/105] unused imports --- tests/test_orient_d2.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 0e7d5b19a5..729afed79a 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -3,15 +3,7 @@ from aspire.abinitio import CLSymmetryD2 from aspire.source import Simulation -from aspire.utils import ( - J_conjugate, - Rotation, - all_pairs, - cyclic_rotations, - mean_aligned_angular_distance, - randn, - utest_tolerance, -) +from aspire.utils import J_conjugate, all_pairs, mean_aligned_angular_distance from aspire.volume import DnSymmetricVolume, DnSymmetryGroup ############## From 566083693e7a8565f06f0b0039854e956160dd3b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 10 May 2024 13:51:26 -0400 Subject: [PATCH 031/105] Add offsets to test. --- tests/test_orient_d2.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 729afed79a..dc6394aca1 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -10,11 +10,16 @@ # Parameters # ############## -DTYPE = [np.float64, np.float32] +DTYPE = [np.float64, pytest.param(np.float32, marks=pytest.mark.expensive)] RESOLUTION = [48, 49] N_IMG = [10] -OFFSETS = [0] -SEED = 42 +OFFSETS = [0, pytest.param(None, marks=pytest.mark.expensive)] + +# Since these tests are optimized for runtime, detuned parameters cause +# the algorithm to be fickle, especially for small problem sizes. +# This seed is chosen so the tests pass CI on github's envs as well +# as our self-hosted runner. +SEED = 3 @pytest.fixture(params=DTYPE, ids=lambda x: f"dtype={x}") @@ -62,10 +67,17 @@ def source(n_img, resolution, dtype, offsets): @pytest.fixture def orient_est(source): + # Search for common lines over less shifts for 0 offsets. + max_shift = 0 + shift_step = 1 + if source.offsets.all() != 0: + max_shift = 0.2 + shift_step = 0.1 # Reduce shift steps for non-integer offsets of Simulation. + orient_est = CLSymmetryD2( source, - max_shift=0, - shift_step=1, + max_shift=max_shift, + shift_step=shift_step, n_theta=360, n_rad=source.L, grid_res=350, # Tuned for speed @@ -94,9 +106,12 @@ def test_estimate_rotations(orient_est): # g-sync ground truth rotations. rots_gt_sync = g_sync_d2(rots_est, rots_gt) - # Register estimates to ground truth rotations and check that the - # mean angular distance between them is less than 5 degrees. - mean_aligned_angular_distance(rots_est, rots_gt_sync, degree_tol=5) + # Register estimates to ground truth rotations and check that the mean angular + # distance between them is less than 5 degrees (7 when testing with offsets). + deg_tol = 5 + if orient_est.src.offsets.all() != 0: + deg_tol = 7 + mean_aligned_angular_distance(rots_est, rots_gt_sync, degree_tol=deg_tol) #################### From 4a299a3637a0e87ac5521000344540d75585ed4b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 13 May 2024 11:25:25 -0400 Subject: [PATCH 032/105] Make non-user functions private. --- src/aspire/abinitio/commonline_d2.py | 107 ++++++++++++++------------- 1 file changed, 55 insertions(+), 52 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index daf7f31994..3da513e117 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -80,17 +80,17 @@ def estimate_rotations(self): :return: Array of rotation matrices, size n_imgx3x3. """ # Pre-compute phase-shifted polar Fourier. - self.compute_shifted_pf() + self._compute_shifted_pf() # Generate lookup data - self.generate_lookup_data() - self.generate_scl_lookup_data() + self._generate_lookup_data() + self._generate_scl_lookup_data() # Compute common-line scores. - self.compute_scl_scores() + self._compute_scl_scores() # Compute common-lines and estimate relative rotations Rijs. - self.compute_cl_scores() + self._compute_cl_scores() # Perform handedness synchronization. self.Rijs_sync = self._global_J_sync(self.Rijs_est) @@ -101,9 +101,10 @@ def estimate_rotations(self): # Synchronize signs. Ris = self._sync_signs(self.Rijs_rows, self.colors) + # Assign rotations. self.rotations = Ris - def compute_shifted_pf(self): + def _compute_shifted_pf(self): """ Pre-compute shifted and full polar Fourier transforms. """ @@ -132,7 +133,7 @@ def compute_shifted_pf(self): (self.n_img, self.n_shifts * (self.n_theta // 2), r_max) ) - def compute_cl_scores(self): + def _compute_cl_scores(self): """ Run common lines Maximum likelihood procedure for a D2 molecule, to find the set of rotations Ri^TgkRj, k=1,2,3,4 for each pair of images i and j. @@ -194,9 +195,9 @@ def compute_cl_scores(self): pbar.close() # Get estimated relative viewing directions. - self.Rijs_est = self.get_Rijs_from_lin_idx(corrs_idx) + self.Rijs_est = self._get_Rijs_from_lin_idx(corrs_idx) - def get_Rijs_from_lin_idx(self, lin_idx): + def _get_Rijs_from_lin_idx(self, lin_idx): """ Restore map results from maximum-likelihood over commonlines to corresponding relative rotations. @@ -206,15 +207,15 @@ def get_Rijs_from_lin_idx(self, lin_idx): oct1_idx = lin_idx < n_cand_per_oct n_est_in_oct1 = np.sum(oct1_idx, dtype=int) if n_est_in_oct1 > 0: - Rijs_est[oct1_idx] = self.get_Rijs_from_oct(lin_idx[oct1_idx], octant=1) + Rijs_est[oct1_idx] = self._get_Rijs_from_oct(lin_idx[oct1_idx], octant=1) if n_est_in_oct1 <= len(lin_idx): - Rijs_est[~oct1_idx] = self.get_Rijs_from_oct( + Rijs_est[~oct1_idx] = self._get_Rijs_from_oct( lin_idx[~oct1_idx] - n_cand_per_oct, octant=2 ) return Rijs_est - def get_Rijs_from_oct(self, lin_idx, octant=1): + def _get_Rijs_from_oct(self, lin_idx, octant=1): if octant not in [1, 2]: raise ValueError("`octant` must be 1 or 2.") @@ -273,7 +274,7 @@ def get_Rijs_from_oct(self, lin_idx, octant=1): return Rijs_est - def compute_scl_scores(self): + def _compute_scl_scores(self): """ Compute correlations for self-commonline candidates. """ @@ -308,7 +309,7 @@ def compute_scl_scores(self): corrs = 0.5 * (corrs + 1) # Compute equator measures. - eq_measures = self.all_eq_measures(corrs) + eq_measures = self._all_eq_measures(corrs) # Handle the cases: Non-equator, Non-top-view equator, and Top view images. # 1. Non-equators: just take product of probabilities. @@ -335,7 +336,7 @@ def compute_scl_scores(self): self.scls_scores = corrs_out - def all_eq_measures(self, corrs): + def _all_eq_measures(self, corrs): """ Compute a measure of how much an image from data is close to be an equator. """ @@ -397,13 +398,13 @@ def all_eq_measures(self, corrs): return corrs_mean * normal_corrs_max - def generate_lookup_data(self): + def _generate_lookup_data(self): """ Generate candidate relative rotations and corresponding common line indices. """ # Generate uniform grid on sphere with Saff-Kuijlaars and take one quarter # of sphere because of D2 symmetry redundancy. - sphere_grid = self.saff_kuijlaars(self.grid_res) + sphere_grid = self._saff_kuijlaars(self.grid_res) octant1_mask = np.all(sphere_grid > 0, axis=1) octant2_mask = ( (sphere_grid[:, 0] > 0) & (sphere_grid[:, 1] > 0) & (sphere_grid[:, 2] < 0) @@ -420,8 +421,8 @@ def generate_lookup_data(self): # We detect such directions by taking a strip of radius # eq_filter_angle about the 3 great circles perpendicular to the symmetry # axes of D2 (i.e to X,Y and Z axes). - eq_idx1, eq_class1 = self.mark_equators(sphere_grid1, self.eq_min_dist) - eq_idx2, eq_class2 = self.mark_equators(sphere_grid2, self.eq_min_dist) + eq_idx1, eq_class1 = self._mark_equators(sphere_grid1, self.eq_min_dist) + eq_idx2, eq_class2 = self._mark_equators(sphere_grid2, self.eq_min_dist) # Mark Top View Directions. # A Top view projection image is taken from the direction of one of the @@ -447,15 +448,15 @@ def generate_lookup_data(self): self.eq_class2 = eq_class2[eq_class2 < 4] # Generate in-plane rotations for each grid point on the sphere. - self.inplane_rotated_grid1 = self.generate_inplane_rots( + self.inplane_rotated_grid1 = self._generate_inplane_rots( self.sphere_grid1, self.inplane_res ) - self.inplane_rotated_grid2 = self.generate_inplane_rots( + self.inplane_rotated_grid2 = self._generate_inplane_rots( self.sphere_grid2, self.inplane_res ) # Generate commmonline angles induced by all relative rotation candidates. - cl_angles1, self.eq2eq_Rij_table_11 = self.generate_commonline_angles( + cl_angles1, self.eq2eq_Rij_table_11 = self._generate_commonline_angles( self.inplane_rotated_grid1, self.inplane_rotated_grid1, self.eq_idx1, @@ -463,7 +464,7 @@ def generate_lookup_data(self): self.eq_class1, self.eq_class1, ) - cl_angles2, self.eq2eq_Rij_table_12 = self.generate_commonline_angles( + cl_angles2, self.eq2eq_Rij_table_12 = self._generate_commonline_angles( self.inplane_rotated_grid1, self.inplane_rotated_grid2, self.eq_idx1, @@ -474,31 +475,31 @@ def generate_lookup_data(self): ) # Generate commonline indices. - self.cl_idx_1, self.cl_angles1 = self.generate_commonline_indices(cl_angles1) - self.cl_idx_2, self.cl_angles2 = self.generate_commonline_indices(cl_angles2) + self.cl_idx_1, self.cl_angles1 = self._generate_commonline_indices(cl_angles1) + self.cl_idx_2, self.cl_angles2 = self._generate_commonline_indices(cl_angles2) self.cl_idx = np.hstack((self.cl_idx_1, self.cl_idx_2)) - def generate_scl_lookup_data(self): + def _generate_scl_lookup_data(self): """ Generate lookup data for self-commonlines. """ # Get self-commonline angles. - self.scl_angles1 = self.generate_scl_angles( + self.scl_angles1 = self._generate_scl_angles( self.inplane_rotated_grid1, self.eq_idx1, self.eq_class1, ) - self.scl_angles2 = self.generate_scl_angles( + self.scl_angles2 = self._generate_scl_angles( self.inplane_rotated_grid2, self.eq_idx2, self.eq_class2, ) # Get self-commonline indices. - self.scl_idx_1, self.scl_eq_lin_idx_lists_1 = self.generate_scl_indices( + self.scl_idx_1, self.scl_eq_lin_idx_lists_1 = self._generate_scl_indices( self.scl_angles1, self.eq_class1 ) - self.scl_idx_2, self.scl_eq_lin_idx_lists_2 = self.generate_scl_indices( + self.scl_idx_2, self.scl_eq_lin_idx_lists_2 = self._generate_scl_indices( self.scl_angles2, self.eq_class2 ) self.scl_idx_lists = np.concatenate( @@ -540,9 +541,9 @@ def generate_scl_lookup_data(self): self.non_tv_eq_idx = non_tv_eq_idx.astype(int) # Generate maps from scl indices to relative rotations. - self.generate_scl_scores_idx_map() + self._generate_scl_scores_idx_map() - def generate_scl_angles(self, Ris, eq_idx, eq_class): + def _generate_scl_angles(self, Ris, eq_idx, eq_class): """ Generate self-commonline angles. @@ -628,7 +629,7 @@ def generate_scl_angles(self, Ris, eq_idx, eq_class): return scl_angles - def generate_scl_indices(self, scl_angles, eq_class): + def _generate_scl_indices(self, scl_angles, eq_class): L = 360 # Create candidate common line linear indices lists for equators. @@ -649,8 +650,8 @@ def generate_scl_indices(self, scl_angles, eq_class): eq_lin_idx_lists = np.empty((2, n_eq, n_inplane_rots), dtype=object) for i in non_top_view_eq_idx.tolist(): for j in range(n_inplane_rots): - idx1 = self.circ_seq(scl_angles[i, j, 0, 0], scl_angles[i, j, 1, 0], L) - idx2 = self.circ_seq(scl_angles[i, j, 0, 1], scl_angles[i, j, 1, 1], L) + idx1 = self._circ_seq(scl_angles[i, j, 0, 0], scl_angles[i, j, 1, 0], L) + idx2 = self._circ_seq(scl_angles[i, j, 0, 1], scl_angles[i, j, 1, 1], L) # Adjust so idx1 is in [0, 180) range. geq_180 = idx1 >= 180 @@ -664,11 +665,11 @@ def generate_scl_indices(self, scl_angles, eq_class): eq_lin_idx_lists[1, count_eq, j] = idx1 count_eq += 1 - scl_indices, _ = self.generate_commonline_indices(scl_angles) + scl_indices, _ = self._generate_commonline_indices(scl_angles) return scl_indices, eq_lin_idx_lists - def generate_scl_scores_idx_map(self): + def _generate_scl_scores_idx_map(self): n_rot_1 = len(self.scl_idx_1) // (3 * self.n_inplane_rots) n_rot_2 = len(self.scl_idx_2) // (3 * self.n_inplane_rots) @@ -954,7 +955,7 @@ def _sync_colors(self, Rijs): # This vector is a linear combination of the two leading eigen vectors, # and so we 'unmix' these vectors to retrieve it. color_mat = la.LinearOperator( - (3 * n_pairs,) * 2, lambda v: self.mult_cmat_by_vec(color_perms, v) + (3 * n_pairs,) * 2, lambda v: self._mult_cmat_by_vec(color_perms, v) ) vals, colors = la.eigs(color_mat, k=3, which="LR") vals = np.real(vals) @@ -1100,7 +1101,7 @@ def _match_colors(self, Rijs_rows): return colors_i - def mult_cmat_by_vec(self, c_perms, v): + def _mult_cmat_by_vec(self, c_perms, v): """ Multiply color matrix by vector v "on the fly". @@ -1343,7 +1344,7 @@ def _sync_signs(self, rr, c_vec): i, j = self.pairs[p] idx_mask = np.full(self.n_img, True) idx_mask[[i, j]] = False - signs[c, p, idx_mask] = self.calc_Rij_prods(c_mat_5d, i, j, c) + signs[c, p, idx_mask] = self._calc_Rij_prods(c_mat_5d, i, j, c) # Now compute the signs of Qij^c. est_signs = np.sign(np.sum(c_mat_4d, axis=(-2, -1))) @@ -1392,7 +1393,9 @@ def _sync_signs(self, rr, c_vec): for c in range(3): for r in range(self.n_img): # Image r used for signs. - c_mat_eff = self.fill_sign_sync_matrix_c(c_mat_5d_mp, sync_signs2, c, r) + c_mat_eff = self._fill_sign_sync_matrix_c( + c_mat_5d_mp, sync_signs2, c, r + ) # Construct (3*N)x(3*N) rank 1 matrices from Qik c_mat_for_svd = np.zeros( @@ -1463,8 +1466,8 @@ def _sync_signs(self, rr, c_vec): smat = la.LinearOperator( shape=(n_pairs, n_pairs), - matvec=lambda v, s=sign_mat: self.mult_smat_by_vec(v, s, pairs_map), - rmatvec=lambda v, s=sign_mat: self.mult_smat_by_vec(v, s, pairs_map), + matvec=lambda v, s=sign_mat: self._mult_smat_by_vec(v, s, pairs_map), + rmatvec=lambda v, s=sign_mat: self._mult_smat_by_vec(v, s, pairs_map), ) U, S, _ = la.svds(smat, k=3, which="LM") U = np.sign(U[0]) * U # Stable svds @@ -1520,13 +1523,13 @@ def _sync_signs(self, rr, c_vec): return rot - def fill_sign_sync_matrix_c(self, c_mat_5d_mp, sync_signs2, c, img): + def _fill_sign_sync_matrix_c(self, c_mat_5d_mp, sync_signs2, c, img): c_mat_eff = np.zeros((self.n_img, self.n_img, 3, 3), dtype=self.dtype) for r in range(self.n_img): c_mat_eff[:, r] = c_mat_5d_mp[r, sync_signs2[c, img, :, r], c] return c_mat_eff - def calc_Rij_prods(self, c_mat_5d, i, j, c): + def _calc_Rij_prods(self, c_mat_5d, i, j, c): Rik = np.delete(c_mat_5d[i, :, c], [i, j], axis=0) Rkj = np.delete(c_mat_5d[:, j, c], [i, j], axis=0) Rij = Rik @ Rkj @@ -1539,7 +1542,7 @@ def calc_Rij_prods(self, c_mat_5d, i, j, c): return np.sign(ij_signs) - def mult_smat_by_vec(self, v, sign_mat, pairs_map): + def _mult_smat_by_vec(self, v, sign_mat, pairs_map): """ Multiplies the signs sync matrix by a vector. """ @@ -1555,7 +1558,7 @@ def mult_smat_by_vec(self, v, sign_mat, pairs_map): #################### @staticmethod - def circ_seq(n1, n2, L): + def _circ_seq(n1, n2, L): """ Make a circular sequence of integers between n1 and n2 modulo L. @@ -1574,7 +1577,7 @@ def circ_seq(n1, n2, L): return seq @staticmethod - def saff_kuijlaars(N): + def _saff_kuijlaars(N): """ Generates N vertices on the unit sphere that are approximately evenly distributed. @@ -1605,7 +1608,7 @@ def saff_kuijlaars(N): return mesh @staticmethod - def mark_equators(sphere_grid, eq_filter_angle): + def _mark_equators(sphere_grid, eq_filter_angle): """ :param sphere_grid: Nx3 array of vertices in cartesian coordinates. :param eq_filter_angle: Angular distance from equator to be marked as @@ -1662,7 +1665,7 @@ def mark_equators(sphere_grid, eq_filter_angle): return eq_idx, eq_class @staticmethod - def generate_inplane_rots(sphere_grid, d_theta): + def _generate_inplane_rots(sphere_grid, d_theta): """ This function takes projection directions (points on the 2-sphere) and generates rotation matrices in SO(3). The projection direction @@ -1705,7 +1708,7 @@ def generate_inplane_rots(sphere_grid, d_theta): return inplane_rotated_grid - def generate_commonline_angles( + def _generate_commonline_angles( self, Ris, Rjs, @@ -1782,7 +1785,7 @@ def generate_commonline_angles( return cl_angles, eq2eq_Rij_table @staticmethod - def generate_commonline_indices(cl_angles): + def _generate_commonline_indices(cl_angles): # TODO: This is not accounting for n_theta other than 360! # Flatten the stack From 398ff8a9f892dd1cc4321ee14a39bd4e56b989c6 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 13 May 2024 11:41:22 -0400 Subject: [PATCH 033/105] Reorganize Algo Sections. --- src/aspire/abinitio/commonline_d2.py | 886 ++++++++++++++------------- 1 file changed, 453 insertions(+), 433 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 3da513e117..4b860b86d7 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -104,6 +104,10 @@ def estimate_rotations(self): # Assign rotations. self.rotations = Ris + ######################### + # Prepare Polar Fourier # + ######################### + def _compute_shifted_pf(self): """ Pre-compute shifted and full polar Fourier transforms. @@ -133,158 +137,360 @@ def _compute_shifted_pf(self): (self.n_img, self.n_shifts * (self.n_theta // 2), r_max) ) - def _compute_cl_scores(self): + ################################### + # Generate Commonline Lookup Data # + ################################### + + def _generate_lookup_data(self): """ - Run common lines Maximum likelihood procedure for a D2 molecule, to find - the set of rotations Ri^TgkRj, k=1,2,3,4 for each pair of images i and j. + Generate candidate relative rotations and corresponding common line indices. """ - # Map the self common line scores of each 2 candidate rotations R_i,R_j to - # the respective relative rotation candidate R_i^TR_j. - n_lookup_1 = len(self.scl_idx_1) // 3 - oct1_ij_map = np.vstack((self.oct1_ij_map, self.oct1_ij_map[:, [1, 0]])) - oct2_ij_map = self.oct2_ij_map - oct2_ij_map[:, 1] += n_lookup_1 - oct2_ij_map = np.vstack((oct2_ij_map, oct2_ij_map[:, [1, 0]])) - ij_map = np.vstack((oct1_ij_map, oct2_ij_map)) - - # Allocate output variables. - n_pairs = self.n_img * (self.n_img - 1) // 2 - corrs_idx = np.zeros(n_pairs, dtype=np.int64) - corrs_out = np.zeros(n_pairs, dtype=self.dtype) - ij_idx = 0 - - # Search for common lines between pairs of projections. - pbar = tqdm( - desc="Searching for commonlines between pairs of images", total=n_pairs + # Generate uniform grid on sphere with Saff-Kuijlaars and take one quarter + # of sphere because of D2 symmetry redundancy. + sphere_grid = self._saff_kuijlaars(self.grid_res) + octant1_mask = np.all(sphere_grid > 0, axis=1) + octant2_mask = ( + (sphere_grid[:, 0] > 0) & (sphere_grid[:, 1] > 0) & (sphere_grid[:, 2] < 0) ) - for i in range(self.n_img): - pf_i = self.pf_shifted[i] - scores_i = self.scls_scores[i] - - for j in range(i + 1, self.n_img): - pf_j = self.pf_full[j] + sphere_grid1 = sphere_grid[octant1_mask] + sphere_grid2 = sphere_grid[octant2_mask] - # Compute maximum correlation over all shifts. - corrs = 2 * np.real(pf_i @ np.conj(pf_j).T) - corrs = np.reshape( - corrs, (self.n_shifts, self.n_theta // 2, self.n_theta) - ) - corrs = np.max(corrs, axis=0) + # Mark Equator Directions. + # Common lines between projection directions which are perpendicular to + # symmetry axes (equator images) have common line degeneracies. Two images + # taken from directions on the same great circle which is perpendicular to + # some symmetry axis only have 2 common lines instead of 4, and must be + # treated separately. + # We detect such directions by taking a strip of radius + # eq_filter_angle about the 3 great circles perpendicular to the symmetry + # axes of D2 (i.e to X,Y and Z axes). + eq_idx1, eq_class1 = self._mark_equators(sphere_grid1, self.eq_min_dist) + eq_idx2, eq_class2 = self._mark_equators(sphere_grid2, self.eq_min_dist) - # Take the product over symmetrically induced candidates. Eq. 4.5 in paper. - cl_idx = np.unravel_index( - self.cl_idx, (self.n_theta // 2, self.n_theta) - ) + # Mark Top View Directions. + # A Top view projection image is taken from the direction of one of the + # symmetry axes. Since all symmetry axes of D2 molecules are perpendicular + # this means that such an image is an equator with repect to both symmetry + # axes which are perpendicular to the direction of the symmetry axis from + # which the image was made, e.g. if the image was formed by projecting in + # the direction of the X (symmetry) axis, then it is an equator with + # respect to both Y and Z symmetry axes (it's direction is the + # interesection of 2 great circles perpendicular to Y and Z axes). + # Such images have severe degeneracies. A pair of Top View images (taken + # from different directions or a Top View and equator image only have a + # single common line. A top view and a regular non-equator image only have + # two common lines. - prod_corrs = corrs[cl_idx] - prod_corrs = prod_corrs.reshape(len(prod_corrs) // 4, 4) - prod_corrs = np.prod(prod_corrs, axis=1) + # Remove top views from sphere grids and update equator indices and classes. + self.sphere_grid1 = sphere_grid1[eq_class1 < 4] + self.sphere_grid2 = sphere_grid2[eq_class2 < 4] + self.eq_idx1 = eq_idx1[eq_class1 < 4] + self.eq_idx2 = eq_idx2[eq_class2 < 4] + self.eq_idx = np.concatenate((self.eq_idx1, self.eq_idx2)) + self.eq_class1 = eq_class1[eq_class1 < 4] + self.eq_class2 = eq_class2[eq_class2 < 4] - # Incorporate scores of individual rotations from self-commonlines. - scores_j = self.scls_scores[j] - scores_ij = scores_i[ij_map[:, 0]] * scores_j[ij_map[:, 1]] + # Generate in-plane rotations for each grid point on the sphere. + self.inplane_rotated_grid1 = self._generate_inplane_rots( + self.sphere_grid1, self.inplane_res + ) + self.inplane_rotated_grid2 = self._generate_inplane_rots( + self.sphere_grid2, self.inplane_res + ) - # Find maximum correlations. - prod_corrs = prod_corrs * scores_ij - max_idx = np.argmax(prod_corrs) - corrs_idx[ij_idx] = max_idx - corrs_out[ij_idx] = prod_corrs[max_idx] - ij_idx += 1 + # Generate commmonline angles induced by all relative rotation candidates. + cl_angles1, self.eq2eq_Rij_table_11 = self._generate_commonline_angles( + self.inplane_rotated_grid1, + self.inplane_rotated_grid1, + self.eq_idx1, + self.eq_idx1, + self.eq_class1, + self.eq_class1, + ) + cl_angles2, self.eq2eq_Rij_table_12 = self._generate_commonline_angles( + self.inplane_rotated_grid1, + self.inplane_rotated_grid2, + self.eq_idx1, + self.eq_idx2, + self.eq_class1, + self.eq_class2, + triu=False, + ) - pbar.update() - pbar.close() + # Generate commonline indices. + self.cl_idx_1, self.cl_angles1 = self._generate_commonline_indices(cl_angles1) + self.cl_idx_2, self.cl_angles2 = self._generate_commonline_indices(cl_angles2) + self.cl_idx = np.hstack((self.cl_idx_1, self.cl_idx_2)) - # Get estimated relative viewing directions. - self.Rijs_est = self._get_Rijs_from_lin_idx(corrs_idx) + ######################################## + # Generate Self-Commonline Lookup Data # + ######################################## - def _get_Rijs_from_lin_idx(self, lin_idx): + def _generate_scl_lookup_data(self): """ - Restore map results from maximum-likelihood over commonlines to corresponding - relative rotations. + Generate lookup data for self-commonlines. """ - Rijs_est = np.zeros((len(lin_idx), 4, 3, 3), dtype=self.dtype) - n_cand_per_oct = len(self.cl_idx_1) // 4 - oct1_idx = lin_idx < n_cand_per_oct - n_est_in_oct1 = np.sum(oct1_idx, dtype=int) - if n_est_in_oct1 > 0: - Rijs_est[oct1_idx] = self._get_Rijs_from_oct(lin_idx[oct1_idx], octant=1) - if n_est_in_oct1 <= len(lin_idx): - Rijs_est[~oct1_idx] = self._get_Rijs_from_oct( - lin_idx[~oct1_idx] - n_cand_per_oct, octant=2 - ) - - return Rijs_est - - def _get_Rijs_from_oct(self, lin_idx, octant=1): - if octant not in [1, 2]: - raise ValueError("`octant` must be 1 or 2.") - - # Get pairs lookup table. - if octant == 1: - unique_pairs = self.eq2eq_Rij_table_11 - else: - unique_pairs = self.eq2eq_Rij_table_12 - - n_theta = self.n_inplane_rots - n_lookup_pairs = np.sum(unique_pairs, dtype=np.int64) - n_rots = len(self.sphere_grid1) - if octant == 1: - n_rots2 = n_rots - else: - n_rots2 = len(self.sphere_grid2) - n_pairs = len(lin_idx) - - # Map linear indices of chosen pairs of rotation candidates from ML to regular indices. - p_idx, inplane_i, inplane_j = np.unravel_index( - lin_idx, (2 * n_lookup_pairs, n_theta, n_theta // 2) + # Get self-commonline angles. + self.scl_angles1 = self._generate_scl_angles( + self.inplane_rotated_grid1, + self.eq_idx1, + self.eq_class1, ) - transpose_idx = p_idx >= n_lookup_pairs - p_idx[transpose_idx] -= n_lookup_pairs - s = self.inplane_rotated_grid1.shape - inplane_rotated_grid = np.reshape( - self.inplane_rotated_grid1, (np.prod(s[0:2]), 3, 3) + self.scl_angles2 = self._generate_scl_angles( + self.inplane_rotated_grid2, + self.eq_idx2, + self.eq_class2, ) - if octant == 1: - s2 = s - inplane_rotated_grid2 = inplane_rotated_grid - else: - s2 = self.inplane_rotated_grid2.shape - inplane_rotated_grid2 = np.reshape( - self.inplane_rotated_grid2, (np.prod(s2[0:2]), 3, 3) - ) - Rijs_est = np.zeros((n_pairs, 4, 3, 3), dtype=self.dtype) + # Get self-commonline indices. + self.scl_idx_1, self.scl_eq_lin_idx_lists_1 = self._generate_scl_indices( + self.scl_angles1, self.eq_class1 + ) + self.scl_idx_2, self.scl_eq_lin_idx_lists_2 = self._generate_scl_indices( + self.scl_angles2, self.eq_class2 + ) + self.scl_idx_lists = np.concatenate( + (self.scl_eq_lin_idx_lists_1, self.scl_eq_lin_idx_lists_2), axis=1 + ) - # Convert linear indices of unique table to linear indices of index pairs table. - idx_vec = np.arange(np.prod(unique_pairs.shape)) - unique_lin_idx = idx_vec[unique_pairs.flatten()] - I, J = np.unravel_index(unique_lin_idx, (n_rots, n_rots2)) - est_idx = np.vstack((I[p_idx], J[p_idx])) + # Compute non-equator indices. + # Register non equator indices. Denote by C_ij the j'th in-plane rotation of + # the i'th ML candidate, and arrange all candidates in a list with their in-plane + # rotations in the order: C_11,...,C_1r,...,C_m1,...,C_mr where m is the + # number of candidates and r is the number of in plane rotations. Here we + # create a sub-list of only non equator candidates, i.e., if i_1,...,i_p are + # non equators then we have the sub list is + # C_(i_1)1,...,C(i_1)r,...C_(i_p)1,...,C_(i_p)r. + n_non_eq = np.sum(self.eq_class1 == 0) + np.sum(self.eq_class2 == 0) + non_eq_idx = np.zeros((n_non_eq, int(self.n_inplane_rots))) + non_eq_idx[:, 0] = ( + np.hstack( + ( + np.where(self.eq_class1 == 0)[0], + len(self.eq_class1) + np.where(self.eq_class2 == 0)[0], + ) + ) + * self.n_inplane_rots + ) + for i in range(1, self.n_inplane_rots): + non_eq_idx[:, i] = non_eq_idx[:, 0] + i - # Assemble relative rotations Ri^TgRj using linear indices, where g is a group member of D2. - Ris_lin_idx = np.ravel_multi_index((est_idx[0], inplane_i), s[:2]) - Rjs_lin_idx = np.ravel_multi_index((est_idx[1], inplane_j), s2[:2]) - Ris_t = np.transpose(inplane_rotated_grid[Ris_lin_idx], (0, 2, 1)) - Rjs = inplane_rotated_grid2[Rjs_lin_idx] + self.non_eq_idx = non_eq_idx.astype(int) - for k, g in enumerate(self.gs): - Rijs_est[:, k] = Ris_t @ (g * Rjs) + # Non-topview equator indices. + non_tv_eq_idx = np.concatenate( + ( + np.where(self.eq_class1 > 0)[0], + len(self.eq_class1) + np.where(self.eq_class2 > 0)[0], + ) + ) - Rijs_est[transpose_idx] = np.transpose(Rijs_est[transpose_idx], (0, 1, 3, 2)) + self.non_tv_eq_idx = non_tv_eq_idx.astype(int) - return Rijs_est + # Generate maps from scl indices to relative rotations. + self._generate_scl_scores_idx_map() - def _compute_scl_scores(self): - """ - Compute correlations for self-commonline candidates. + def _generate_scl_angles(self, Ris, eq_idx, eq_class): """ - n_img = self.n_img - n_theta = self.n_theta - n_eq = len(self.non_tv_eq_idx) - n_inplane = self.n_inplane_rots + Generate self-commonline angles. - # Run ML in parallel - scl_matrix = np.concatenate((self.scl_idx_1, self.scl_idx_2)) + :param Ris: Candidate rotation matrices, (n_sphere_grid, n_inplane_rots, 3, 3). + :param eq_idx: Equator index mask for Ris. + :param eq_class: Equator classification for Ris. + """ + L = 360 # TODO: Maybe this should be self.n_theta + + # For each candidate rotation Ri we generate the set of 3 self-commonlines. + scl_angles = np.zeros((*Ris.shape[:2], 3, 2), dtype=Ris.dtype) + n_rots = len(Ris) + for i in range(n_rots): + Ri = Ris[i] + # TODO: Reversing self.gs here to match matlab. Should use as is. + for k, g in enumerate(self.gs[::-1][:3]): + g_Ri = g * Ri + Riis = np.transpose(Ri, axes=(0, 2, 1)) @ g_Ri + + scl_angles[i, :, k, 0] = np.arctan2(Riis[:, 2, 0], -Riis[:, 2, 1]) + scl_angles[i, :, k, 1] = np.arctan2(-Riis[:, 0, 2], Riis[:, 1, 2]) + + # Prepare self commonline coordinates. + scl_angles = scl_angles % (2 * np.pi) + + # Deal with non top view equators + # A non-TV equator has only one self common line. However, we clasify an + # equator as an image whose projection direction is at radial distance < + # eq_filter_angle from the great circle perpendicural to a symmetry axis, + # and not strictly zero distance. Thus in most cases we get 2 common lines + # differing by a small difference in degrees. Actually the calculation above + # gives us two NEARLY antipodal lines, so we first flip one of them by + # adding 180 degrees to it. Then we aggregate all the rays within the range + # between these two resulting lines to compute the score of this self common + # line for this candidate. The scoring part is done in the ML function itself. + # Furthermore, the line perpendicular to the self common line, though not + # really a self common line, has the property that all its values are real + # and both halves of the line (rays differing by pi, emanating from the + # origin) have the same values, and so it 'behaves like' a self common + # line which we also register here and exploit in the ML function. + # We put the 'real' self common line at 2 first coordinates, the + # candidate for perpendicular line is in 3rd coordinate. + + # If this is a self common line with respect to x-equator then the actual self + # common line(s) is given by the self relative rotations given by the y and z + # rotation (by 180 degrees) group members, i.e. Ri^TgyRj and Ri^TgzRj + scl_angles[eq_class == 1] = scl_angles[eq_class == 1][:, :, [1, 2, 0]] + scl_angles[eq_class == 1, :, 0] = scl_angles[eq_class == 1][:, :, 0, [1, 0]] + + # If this is a self common line with respect to y-equator then the actual self + # common line(s) is given by the self relative rotations given by the x and z + # rotation (by 180 degrees) group members, i.e. Ri^TgxRj and Ri^TgzRj + scl_angles[eq_class == 2] = scl_angles[eq_class == 2][:, :, [0, 2, 1]] + scl_angles[eq_class == 2, :, 0] = scl_angles[eq_class == 2][:, :, 0, [1, 0]] + + # If this is a self common line with respect to z-equator then the actual self + # common line(s) is given by the self relative rotations given by the x and y + # rotation (by 180 degrees) group members, i.e. Ri^TgxRj and Ri^TgyRj + # No need to rearrange entries, the "real" common lines are already in + # indices 1 and 2, but flip one common line to antipodal. + scl_angles[eq_class == 3, :, 0] = scl_angles[eq_class == 3][:, :, 0, [1, 0]] + + # TODO: Maybe a cleaner way to do this. + # Make sure angle range is <= 180 degrees. + # p1 marks "equator" self-commonlines where both entries of the first + # scl are greater than both entries of the second scl. + p1 = scl_angles[eq_class > 0, :, 0] > scl_angles[eq_class > 0, :, 1] + p1 = p1[:, :, 0] & p1[:, :, 1] + # p2 marks "equator" self-commonlines where the angle range between the + # first and second sets of self-commonlines is greater than 180. + p2 = scl_angles[eq_class > 0, :, 0] - scl_angles[eq_class > 0, :, 1] < -np.pi + p2 = p2[:, :, 0] | p2[:, :, 1] + p = p1 | p2 + + # Swap entries satisfying either of the above conditions. + scl_angles[eq_class > 0] = ( + scl_angles[eq_class > 0][:, :, [1, 0, 2]] * p[:, :, None, None] + + scl_angles[eq_class > 0] * ~p[:, :, None, None] + ) + + # Convert angles from radians to degrees (indices). + scl_angles = np.round(scl_angles * 180 / np.pi) % L + + return scl_angles + + def _generate_scl_indices(self, scl_angles, eq_class): + L = 360 + + # Create candidate common line linear indices lists for equators. + # As indicated above for equator candidate, for each self common line we + # don't get a single coordinate but a range of them. Here we register a + # list of coordinates for each such self common line candidate. + non_top_view_eq_idx = np.where(eq_class > 0)[0] + n_eq = len(non_top_view_eq_idx) + n_inplane_rots = scl_angles.shape[1] + count_eq = 0 + + # eq_lin_idx_lists[1,i,j] registers a list of linear indices of the j'th + # in-plane rotation of the range for the (only) self common line of the i'th + # candidate. eq_lin_idx_lists[2,i,j] registers the actual (integer) angle + # of the self common line in the 2D Fourier space. Note that we need only + # one number since each self common line has radial coordinates of the form + # (theta, theta+180). + eq_lin_idx_lists = np.empty((2, n_eq, n_inplane_rots), dtype=object) + for i in non_top_view_eq_idx.tolist(): + for j in range(n_inplane_rots): + idx1 = self._circ_seq(scl_angles[i, j, 0, 0], scl_angles[i, j, 1, 0], L) + idx2 = self._circ_seq(scl_angles[i, j, 0, 1], scl_angles[i, j, 1, 1], L) + + # Adjust so idx1 is in [0, 180) range. + geq_180 = idx1 >= 180 + idx1[geq_180] = (idx1[geq_180] - L // 2) % (L // 2) + idx2[geq_180] = (idx2[geq_180] + L // 2) % L + + # register indices in list. + eq_lin_idx_lists[0, count_eq, j] = np.ravel_multi_index( + (idx1, idx2), (L // 2, L) + ) + eq_lin_idx_lists[1, count_eq, j] = idx1 + count_eq += 1 + + scl_indices, _ = self._generate_commonline_indices(scl_angles) + + return scl_indices, eq_lin_idx_lists + + def _generate_scl_scores_idx_map(self): + n_rot_1 = len(self.scl_idx_1) // (3 * self.n_inplane_rots) + n_rot_2 = len(self.scl_idx_2) // (3 * self.n_inplane_rots) + + # First the map for i 0] + if len(unique_pairs_i) == 0: + continue + i_idx_plus_offset = i_idx + (i * self.n_inplane_rots) + + for j in unique_pairs_i: + j_idx_plus_offset = j_idx + (j * self.n_inplane_rots) + oct2_ij_map[:, :, idx] = np.column_stack( + (i_idx_plus_offset, j_idx_plus_offset) + ) + idx += 1 + + tmp1 = oct1_ij_map[:, 0, :] + tmp2 = oct1_ij_map[:, 1, :] + self.oct1_ij_map = np.column_stack( + (tmp1.flatten(order="F"), tmp2.flatten(order="F")) + ) + + tmp1 = oct2_ij_map[:, 0, :] + tmp2 = oct2_ij_map[:, 1, :] + self.oct2_ij_map = np.column_stack( + (tmp1.flatten(order="F"), tmp2.flatten(order="F")) + ) + + ############################################## + # Compute Self-Commonline Correlation Scores # + ############################################## + + def _compute_scl_scores(self): + """ + Compute correlations for self-commonline candidates. + """ + n_img = self.n_img + n_theta = self.n_theta + n_eq = len(self.non_tv_eq_idx) + n_inplane = self.n_inplane_rots + + # Run ML in parallel + scl_matrix = np.concatenate((self.scl_idx_1, self.scl_idx_2)) M = len(scl_matrix) // 3 corrs_out = np.zeros((n_img, M), dtype=self.dtype) scl_idx = scl_matrix.reshape(M, 3) @@ -398,340 +604,154 @@ def _all_eq_measures(self, corrs): return corrs_mean * normal_corrs_max - def _generate_lookup_data(self): + ######################################### + # Compute Commonline Correlation Scores # + ######################################### + + def _compute_cl_scores(self): """ - Generate candidate relative rotations and corresponding common line indices. + Run common lines Maximum likelihood procedure for a D2 molecule, to find + the set of rotations Ri^TgkRj, k=1,2,3,4 for each pair of images i and j. """ - # Generate uniform grid on sphere with Saff-Kuijlaars and take one quarter - # of sphere because of D2 symmetry redundancy. - sphere_grid = self._saff_kuijlaars(self.grid_res) - octant1_mask = np.all(sphere_grid > 0, axis=1) - octant2_mask = ( - (sphere_grid[:, 0] > 0) & (sphere_grid[:, 1] > 0) & (sphere_grid[:, 2] < 0) - ) - sphere_grid1 = sphere_grid[octant1_mask] - sphere_grid2 = sphere_grid[octant2_mask] + # Map the self common line scores of each 2 candidate rotations R_i,R_j to + # the respective relative rotation candidate R_i^TR_j. + n_lookup_1 = len(self.scl_idx_1) // 3 + oct1_ij_map = np.vstack((self.oct1_ij_map, self.oct1_ij_map[:, [1, 0]])) + oct2_ij_map = self.oct2_ij_map + oct2_ij_map[:, 1] += n_lookup_1 + oct2_ij_map = np.vstack((oct2_ij_map, oct2_ij_map[:, [1, 0]])) + ij_map = np.vstack((oct1_ij_map, oct2_ij_map)) - # Mark Equator Directions. - # Common lines between projection directions which are perpendicular to - # symmetry axes (equator images) have common line degeneracies. Two images - # taken from directions on the same great circle which is perpendicular to - # some symmetry axis only have 2 common lines instead of 4, and must be - # treated separately. - # We detect such directions by taking a strip of radius - # eq_filter_angle about the 3 great circles perpendicular to the symmetry - # axes of D2 (i.e to X,Y and Z axes). - eq_idx1, eq_class1 = self._mark_equators(sphere_grid1, self.eq_min_dist) - eq_idx2, eq_class2 = self._mark_equators(sphere_grid2, self.eq_min_dist) + # Allocate output variables. + n_pairs = self.n_img * (self.n_img - 1) // 2 + corrs_idx = np.zeros(n_pairs, dtype=np.int64) + corrs_out = np.zeros(n_pairs, dtype=self.dtype) + ij_idx = 0 - # Mark Top View Directions. - # A Top view projection image is taken from the direction of one of the - # symmetry axes. Since all symmetry axes of D2 molecules are perpendicular - # this means that such an image is an equator with repect to both symmetry - # axes which are perpendicular to the direction of the symmetry axis from - # which the image was made, e.g. if the image was formed by projecting in - # the direction of the X (symmetry) axis, then it is an equator with - # respect to both Y and Z symmetry axes (it's direction is the - # interesection of 2 great circles perpendicular to Y and Z axes). - # Such images have severe degeneracies. A pair of Top View images (taken - # from different directions or a Top View and equator image only have a - # single common line. A top view and a regular non-equator image only have - # two common lines. + # Search for common lines between pairs of projections. + pbar = tqdm( + desc="Searching for commonlines between pairs of images", total=n_pairs + ) + for i in range(self.n_img): + pf_i = self.pf_shifted[i] + scores_i = self.scls_scores[i] - # Remove top views from sphere grids and update equator indices and classes. - self.sphere_grid1 = sphere_grid1[eq_class1 < 4] - self.sphere_grid2 = sphere_grid2[eq_class2 < 4] - self.eq_idx1 = eq_idx1[eq_class1 < 4] - self.eq_idx2 = eq_idx2[eq_class2 < 4] - self.eq_idx = np.concatenate((self.eq_idx1, self.eq_idx2)) - self.eq_class1 = eq_class1[eq_class1 < 4] - self.eq_class2 = eq_class2[eq_class2 < 4] - - # Generate in-plane rotations for each grid point on the sphere. - self.inplane_rotated_grid1 = self._generate_inplane_rots( - self.sphere_grid1, self.inplane_res - ) - self.inplane_rotated_grid2 = self._generate_inplane_rots( - self.sphere_grid2, self.inplane_res - ) - - # Generate commmonline angles induced by all relative rotation candidates. - cl_angles1, self.eq2eq_Rij_table_11 = self._generate_commonline_angles( - self.inplane_rotated_grid1, - self.inplane_rotated_grid1, - self.eq_idx1, - self.eq_idx1, - self.eq_class1, - self.eq_class1, - ) - cl_angles2, self.eq2eq_Rij_table_12 = self._generate_commonline_angles( - self.inplane_rotated_grid1, - self.inplane_rotated_grid2, - self.eq_idx1, - self.eq_idx2, - self.eq_class1, - self.eq_class2, - triu=False, - ) - - # Generate commonline indices. - self.cl_idx_1, self.cl_angles1 = self._generate_commonline_indices(cl_angles1) - self.cl_idx_2, self.cl_angles2 = self._generate_commonline_indices(cl_angles2) - self.cl_idx = np.hstack((self.cl_idx_1, self.cl_idx_2)) - - def _generate_scl_lookup_data(self): - """ - Generate lookup data for self-commonlines. - """ - # Get self-commonline angles. - self.scl_angles1 = self._generate_scl_angles( - self.inplane_rotated_grid1, - self.eq_idx1, - self.eq_class1, - ) - self.scl_angles2 = self._generate_scl_angles( - self.inplane_rotated_grid2, - self.eq_idx2, - self.eq_class2, - ) + for j in range(i + 1, self.n_img): + pf_j = self.pf_full[j] - # Get self-commonline indices. - self.scl_idx_1, self.scl_eq_lin_idx_lists_1 = self._generate_scl_indices( - self.scl_angles1, self.eq_class1 - ) - self.scl_idx_2, self.scl_eq_lin_idx_lists_2 = self._generate_scl_indices( - self.scl_angles2, self.eq_class2 - ) - self.scl_idx_lists = np.concatenate( - (self.scl_eq_lin_idx_lists_1, self.scl_eq_lin_idx_lists_2), axis=1 - ) + # Compute maximum correlation over all shifts. + corrs = 2 * np.real(pf_i @ np.conj(pf_j).T) + corrs = np.reshape( + corrs, (self.n_shifts, self.n_theta // 2, self.n_theta) + ) + corrs = np.max(corrs, axis=0) - # Compute non-equator indices. - # Register non equator indices. Denote by C_ij the j'th in-plane rotation of - # the i'th ML candidate, and arrange all candidates in a list with their in-plane - # rotations in the order: C_11,...,C_1r,...,C_m1,...,C_mr where m is the - # number of candidates and r is the number of in plane rotations. Here we - # create a sub-list of only non equator candidates, i.e., if i_1,...,i_p are - # non equators then we have the sub list is - # C_(i_1)1,...,C(i_1)r,...C_(i_p)1,...,C_(i_p)r. - n_non_eq = np.sum(self.eq_class1 == 0) + np.sum(self.eq_class2 == 0) - non_eq_idx = np.zeros((n_non_eq, int(self.n_inplane_rots))) - non_eq_idx[:, 0] = ( - np.hstack( - ( - np.where(self.eq_class1 == 0)[0], - len(self.eq_class1) + np.where(self.eq_class2 == 0)[0], + # Take the product over symmetrically induced candidates. Eq. 4.5 in paper. + cl_idx = np.unravel_index( + self.cl_idx, (self.n_theta // 2, self.n_theta) ) - ) - * self.n_inplane_rots - ) - for i in range(1, self.n_inplane_rots): - non_eq_idx[:, i] = non_eq_idx[:, 0] + i - self.non_eq_idx = non_eq_idx.astype(int) + prod_corrs = corrs[cl_idx] + prod_corrs = prod_corrs.reshape(len(prod_corrs) // 4, 4) + prod_corrs = np.prod(prod_corrs, axis=1) - # Non-topview equator indices. - non_tv_eq_idx = np.concatenate( - ( - np.where(self.eq_class1 > 0)[0], - len(self.eq_class1) + np.where(self.eq_class2 > 0)[0], - ) - ) + # Incorporate scores of individual rotations from self-commonlines. + scores_j = self.scls_scores[j] + scores_ij = scores_i[ij_map[:, 0]] * scores_j[ij_map[:, 1]] - self.non_tv_eq_idx = non_tv_eq_idx.astype(int) + # Find maximum correlations. + prod_corrs = prod_corrs * scores_ij + max_idx = np.argmax(prod_corrs) + corrs_idx[ij_idx] = max_idx + corrs_out[ij_idx] = prod_corrs[max_idx] + ij_idx += 1 - # Generate maps from scl indices to relative rotations. - self._generate_scl_scores_idx_map() + pbar.update() + pbar.close() - def _generate_scl_angles(self, Ris, eq_idx, eq_class): - """ - Generate self-commonline angles. + # Get estimated relative viewing directions. + self.Rijs_est = self._get_Rijs_from_lin_idx(corrs_idx) - :param Ris: Candidate rotation matrices, (n_sphere_grid, n_inplane_rots, 3, 3). - :param eq_idx: Equator index mask for Ris. - :param eq_class: Equator classification for Ris. + def _get_Rijs_from_lin_idx(self, lin_idx): """ - L = 360 # TODO: Maybe this should be self.n_theta - - # For each candidate rotation Ri we generate the set of 3 self-commonlines. - scl_angles = np.zeros((*Ris.shape[:2], 3, 2), dtype=Ris.dtype) - n_rots = len(Ris) - for i in range(n_rots): - Ri = Ris[i] - # TODO: Reversing self.gs here to match matlab. Should use as is. - for k, g in enumerate(self.gs[::-1][:3]): - g_Ri = g * Ri - Riis = np.transpose(Ri, axes=(0, 2, 1)) @ g_Ri - - scl_angles[i, :, k, 0] = np.arctan2(Riis[:, 2, 0], -Riis[:, 2, 1]) - scl_angles[i, :, k, 1] = np.arctan2(-Riis[:, 0, 2], Riis[:, 1, 2]) - - # Prepare self commonline coordinates. - scl_angles = scl_angles % (2 * np.pi) - - # Deal with non top view equators - # A non-TV equator has only one self common line. However, we clasify an - # equator as an image whose projection direction is at radial distance < - # eq_filter_angle from the great circle perpendicural to a symmetry axis, - # and not strictly zero distance. Thus in most cases we get 2 common lines - # differing by a small difference in degrees. Actually the calculation above - # gives us two NEARLY antipodal lines, so we first flip one of them by - # adding 180 degrees to it. Then we aggregate all the rays within the range - # between these two resulting lines to compute the score of this self common - # line for this candidate. The scoring part is done in the ML function itself. - # Furthermore, the line perpendicular to the self common line, though not - # really a self common line, has the property that all its values are real - # and both halves of the line (rays differing by pi, emanating from the - # origin) have the same values, and so it 'behaves like' a self common - # line which we also register here and exploit in the ML function. - # We put the 'real' self common line at 2 first coordinates, the - # candidate for perpendicular line is in 3rd coordinate. + Restore map results from maximum-likelihood over commonlines to corresponding + relative rotations. + """ + Rijs_est = np.zeros((len(lin_idx), 4, 3, 3), dtype=self.dtype) + n_cand_per_oct = len(self.cl_idx_1) // 4 + oct1_idx = lin_idx < n_cand_per_oct + n_est_in_oct1 = np.sum(oct1_idx, dtype=int) + if n_est_in_oct1 > 0: + Rijs_est[oct1_idx] = self._get_Rijs_from_oct(lin_idx[oct1_idx], octant=1) + if n_est_in_oct1 <= len(lin_idx): + Rijs_est[~oct1_idx] = self._get_Rijs_from_oct( + lin_idx[~oct1_idx] - n_cand_per_oct, octant=2 + ) - # If this is a self common line with respect to x-equator then the actual self - # common line(s) is given by the self relative rotations given by the y and z - # rotation (by 180 degrees) group members, i.e. Ri^TgyRj and Ri^TgzRj - scl_angles[eq_class == 1] = scl_angles[eq_class == 1][:, :, [1, 2, 0]] - scl_angles[eq_class == 1, :, 0] = scl_angles[eq_class == 1][:, :, 0, [1, 0]] + return Rijs_est - # If this is a self common line with respect to y-equator then the actual self - # common line(s) is given by the self relative rotations given by the x and z - # rotation (by 180 degrees) group members, i.e. Ri^TgxRj and Ri^TgzRj - scl_angles[eq_class == 2] = scl_angles[eq_class == 2][:, :, [0, 2, 1]] - scl_angles[eq_class == 2, :, 0] = scl_angles[eq_class == 2][:, :, 0, [1, 0]] + def _get_Rijs_from_oct(self, lin_idx, octant=1): + if octant not in [1, 2]: + raise ValueError("`octant` must be 1 or 2.") - # If this is a self common line with respect to z-equator then the actual self - # common line(s) is given by the self relative rotations given by the x and y - # rotation (by 180 degrees) group members, i.e. Ri^TgxRj and Ri^TgyRj - # No need to rearrange entries, the "real" common lines are already in - # indices 1 and 2, but flip one common line to antipodal. - scl_angles[eq_class == 3, :, 0] = scl_angles[eq_class == 3][:, :, 0, [1, 0]] + # Get pairs lookup table. + if octant == 1: + unique_pairs = self.eq2eq_Rij_table_11 + else: + unique_pairs = self.eq2eq_Rij_table_12 - # TODO: Maybe a cleaner way to do this. - # Make sure angle range is <= 180 degrees. - # p1 marks "equator" self-commonlines where both entries of the first - # scl are greater than both entries of the second scl. - p1 = scl_angles[eq_class > 0, :, 0] > scl_angles[eq_class > 0, :, 1] - p1 = p1[:, :, 0] & p1[:, :, 1] - # p2 marks "equator" self-commonlines where the angle range between the - # first and second sets of self-commonlines is greater than 180. - p2 = scl_angles[eq_class > 0, :, 0] - scl_angles[eq_class > 0, :, 1] < -np.pi - p2 = p2[:, :, 0] | p2[:, :, 1] - p = p1 | p2 + n_theta = self.n_inplane_rots + n_lookup_pairs = np.sum(unique_pairs, dtype=np.int64) + n_rots = len(self.sphere_grid1) + if octant == 1: + n_rots2 = n_rots + else: + n_rots2 = len(self.sphere_grid2) + n_pairs = len(lin_idx) - # Swap entries satisfying either of the above conditions. - scl_angles[eq_class > 0] = ( - scl_angles[eq_class > 0][:, :, [1, 0, 2]] * p[:, :, None, None] - + scl_angles[eq_class > 0] * ~p[:, :, None, None] + # Map linear indices of chosen pairs of rotation candidates from ML to regular indices. + p_idx, inplane_i, inplane_j = np.unravel_index( + lin_idx, (2 * n_lookup_pairs, n_theta, n_theta // 2) ) - - # Convert angles from radians to degrees (indices). - scl_angles = np.round(scl_angles * 180 / np.pi) % L - - return scl_angles - - def _generate_scl_indices(self, scl_angles, eq_class): - L = 360 - - # Create candidate common line linear indices lists for equators. - # As indicated above for equator candidate, for each self common line we - # don't get a single coordinate but a range of them. Here we register a - # list of coordinates for each such self common line candidate. - non_top_view_eq_idx = np.where(eq_class > 0)[0] - n_eq = len(non_top_view_eq_idx) - n_inplane_rots = scl_angles.shape[1] - count_eq = 0 - - # eq_lin_idx_lists[1,i,j] registers a list of linear indices of the j'th - # in-plane rotation of the range for the (only) self common line of the i'th - # candidate. eq_lin_idx_lists[2,i,j] registers the actual (integer) angle - # of the self common line in the 2D Fourier space. Note that we need only - # one number since each self common line has radial coordinates of the form - # (theta, theta+180). - eq_lin_idx_lists = np.empty((2, n_eq, n_inplane_rots), dtype=object) - for i in non_top_view_eq_idx.tolist(): - for j in range(n_inplane_rots): - idx1 = self._circ_seq(scl_angles[i, j, 0, 0], scl_angles[i, j, 1, 0], L) - idx2 = self._circ_seq(scl_angles[i, j, 0, 1], scl_angles[i, j, 1, 1], L) - - # Adjust so idx1 is in [0, 180) range. - geq_180 = idx1 >= 180 - idx1[geq_180] = (idx1[geq_180] - L // 2) % (L // 2) - idx2[geq_180] = (idx2[geq_180] + L // 2) % L - - # register indices in list. - eq_lin_idx_lists[0, count_eq, j] = np.ravel_multi_index( - (idx1, idx2), (L // 2, L) - ) - eq_lin_idx_lists[1, count_eq, j] = idx1 - count_eq += 1 - - scl_indices, _ = self._generate_commonline_indices(scl_angles) - - return scl_indices, eq_lin_idx_lists - - def _generate_scl_scores_idx_map(self): - n_rot_1 = len(self.scl_idx_1) // (3 * self.n_inplane_rots) - n_rot_2 = len(self.scl_idx_2) // (3 * self.n_inplane_rots) - - # First the map for i= n_lookup_pairs + p_idx[transpose_idx] -= n_lookup_pairs + s = self.inplane_rotated_grid1.shape + inplane_rotated_grid = np.reshape( + self.inplane_rotated_grid1, (np.prod(s[0:2]), 3, 3) ) - i_idx = np.repeat(np.arange(self.n_inplane_rots), self.n_inplane_rots // 2) - j_idx = np.tile(np.arange(self.n_inplane_rots // 2), self.n_inplane_rots) - idx_vec = np.arange(n_rot_1) - idx = 0 - - for i in range(n_rot_1): - unique_pairs_i = idx_vec[self.eq2eq_Rij_table_11[i]] - if len(unique_pairs_i) == 0: - continue - i_idx_plus_offset = i_idx + (i * self.n_inplane_rots) + if octant == 1: + s2 = s + inplane_rotated_grid2 = inplane_rotated_grid + else: + s2 = self.inplane_rotated_grid2.shape + inplane_rotated_grid2 = np.reshape( + self.inplane_rotated_grid2, (np.prod(s2[0:2]), 3, 3) + ) - for j in unique_pairs_i: - j_idx_plus_offset = j_idx + (j * self.n_inplane_rots) - oct1_ij_map[:, :, idx] = np.column_stack( - (i_idx_plus_offset, j_idx_plus_offset) - ) - idx += 1 + Rijs_est = np.zeros((n_pairs, 4, 3, 3), dtype=self.dtype) - # First the map for i 0] - if len(unique_pairs_i) == 0: - continue - i_idx_plus_offset = i_idx + (i * self.n_inplane_rots) + # Assemble relative rotations Ri^TgRj using linear indices, where g is a group member of D2. + Ris_lin_idx = np.ravel_multi_index((est_idx[0], inplane_i), s[:2]) + Rjs_lin_idx = np.ravel_multi_index((est_idx[1], inplane_j), s2[:2]) + Ris_t = np.transpose(inplane_rotated_grid[Ris_lin_idx], (0, 2, 1)) + Rjs = inplane_rotated_grid2[Rjs_lin_idx] - for j in unique_pairs_i: - j_idx_plus_offset = j_idx + (j * self.n_inplane_rots) - oct2_ij_map[:, :, idx] = np.column_stack( - (i_idx_plus_offset, j_idx_plus_offset) - ) - idx += 1 + for k, g in enumerate(self.gs): + Rijs_est[:, k] = Ris_t @ (g * Rjs) - tmp1 = oct1_ij_map[:, 0, :] - tmp2 = oct1_ij_map[:, 1, :] - self.oct1_ij_map = np.column_stack( - (tmp1.flatten(order="F"), tmp2.flatten(order="F")) - ) + Rijs_est[transpose_idx] = np.transpose(Rijs_est[transpose_idx], (0, 1, 3, 2)) - tmp1 = oct2_ij_map[:, 0, :] - tmp2 = oct2_ij_map[:, 1, :] - self.oct2_ij_map = np.column_stack( - (tmp1.flatten(order="F"), tmp2.flatten(order="F")) - ) + return Rijs_est - ############################# - # Methods for Global J Sync # - ############################# + #################################### + # Perform Global J Synchronization # + #################################### def _global_J_sync(self, Rijs): """ From 8f142c753c8d7fa063c518a94611db291d00809b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 14 May 2024 11:51:48 -0400 Subject: [PATCH 034/105] Test for global_J_sync. --- src/aspire/abinitio/commonline_d2.py | 25 +++++------ tests/test_orient_d2.py | 65 +++++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 14 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 4b860b86d7..60e6553954 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -72,6 +72,7 @@ def __init__( self._generate_gs() self.triplets = all_triplets(self.n_img) self.pairs, self.pairs_to_linear = all_pairs(self.n_img, return_map=True) + self.n_pairs = len(self.pairs) def estimate_rotations(self): """ @@ -128,7 +129,7 @@ def _compute_shifted_pf(self): pf *= ( np.sqrt(2) / 2 ) # Magic number to match matlab pf. (root2 over 2) Remove after debug. - pf = pf[:, :, ::-1] # also to match matlab + pf = pf[:, :, ::-1] # also to match matlab. Can remove. self.pf_full = PolarFT.half_to_full(pf) # Pre-compute shifted pf's. @@ -791,8 +792,6 @@ def _J_configuration(self, Rijs): :return: List of n-choose-3 indices in {0,1,2,3} indicating which J-configuration for each triplet of Rijs, i, where the rows and columns of S are indexed by # double indexes (i,j), 1<=i Date: Tue, 14 May 2024 14:24:38 -0400 Subject: [PATCH 035/105] Fix J-sync test logic. --- tests/test_orient_d2.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index f8177ca52f..933ce01091 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -98,7 +98,10 @@ def orient_est(source): def test_estimate_rotations(orient_est): """ This test runs through the complete D2 algorithm and compares the - estimated rotations to the ground truth rotations. + estimated rotations to the ground truth rotations. In particular, + we check that the estimates are close to the ground truth up to + a local rotation by a D2 symmetry group member, a global J-conjugation, + and a globally aligning rotation. """ # Estimate rotations. orient_est.estimate_rotations() @@ -170,7 +173,7 @@ def test_global_J_sync(orient_est): # Perform global J-synchronization and check that # Rijs_sync is equal to either Rijs or J_conjugate(Rijs). Rijs_sync = orient_est._global_J_sync(Rijs_conj) - need_to_conj_Rijs = ~np.allclose(Rijs_sync[inds][0], Rijs[inds][0]) + need_to_conj_Rijs = not np.allclose(Rijs_sync[inds][0], Rijs[inds][0]) if need_to_conj_Rijs: np.testing.assert_allclose(Rijs_sync, J_conjugate(Rijs)) else: From 8bdeb53c02fe1800010397be8a3485f22bee1864 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 17 May 2024 12:06:28 -0400 Subject: [PATCH 036/105] Seed initial vec for scipy eigs. single triplet J-sync test. Color sync test. --- src/aspire/abinitio/commonline_d2.py | 14 ++-- tests/test_orient_d2.py | 98 +++++++++++++++++++++++++++- 2 files changed, 104 insertions(+), 8 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 60e6553954..eebcf0e0a7 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -126,9 +126,7 @@ def _compute_shifted_pf(self): # Reconstruct full polar Fourier for use in correlation. pf[:, :, 0] = 0 # Matching matlab convention to zero out the lowest frequency. pf /= norm(pf, axis=2)[..., np.newaxis] # Normalize each ray. - pf *= ( - np.sqrt(2) / 2 - ) # Magic number to match matlab pf. (root2 over 2) Remove after debug. + pf *= np.sqrt(2) / 2 # Magic number to match matlab pf. Remove after debug. pf = pf[:, :, ::-1] # also to match matlab. Can remove. self.pf_full = PolarFT.half_to_full(pf) @@ -977,7 +975,11 @@ def _sync_colors(self, Rijs): color_mat = la.LinearOperator( (3 * n_pairs,) * 2, lambda v: self._mult_cmat_by_vec(color_perms, v) ) - vals, colors = la.eigs(color_mat, k=3, which="LR") + v0 = randn( + 3 * n_pairs, seed=self.seed + ) # Seed eigs initial vector for iterative method + v0 = v0 / norm(v0) + vals, colors = la.eigs(color_mat, k=3, which="LM", v0=v0) # Changed from "LR" vals = np.real(vals) colors = np.real(colors) colors = np.sign(colors[0]) * colors # Stable eigs @@ -1269,7 +1271,7 @@ def R_theta(theta): else: colors[i] = p_i_sqr colors = colors.flatten() - colors = 2 - colors # For debug. remove + # colors = 2 - colors # For debug. remove return colors, best_unmix ##################### @@ -1283,7 +1285,7 @@ def _sync_signs(self, rr, c_vec): and the matrices Ri are assembled. """ # Partition the union of tuples {0.5*(Ri^TRj+Ri^TgkRj), k=1:3} according - # to the color partition established in color synchroniztion procedure. + # to the color partition established in color synchronization procedure. # The partition is stored in two different arrays each with the purpose # of a computational speed up for two different computations performed # later (space considerations are of little concern since arrays are ~ diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 933ce01091..079acc8a25 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -3,7 +3,13 @@ from aspire.abinitio import CLSymmetryD2 from aspire.source import Simulation -from aspire.utils import J_conjugate, all_pairs, mean_aligned_angular_distance +from aspire.utils import ( + J_conjugate, + Rotation, + all_pairs, + mean_aligned_angular_distance, + utest_tolerance, +) from aspire.volume import DnSymmetricVolume, DnSymmetryGroup ############## @@ -138,7 +144,9 @@ def test_global_J_sync(orient_est): # J-conjugate a random set of Rijs. Rijs_conj = Rijs.copy() - inds = np.random.choice(orient_est.n_pairs, size=15, replace=False) + inds = np.random.choice( + orient_est.n_pairs, size=orient_est.n_pairs // 2, replace=False + ) Rijs_conj[inds] = J_conjugate(Rijs[inds]) # Create J-configuration conditions for the triplet Rij, Rjk, Rik. @@ -180,6 +188,92 @@ def test_global_J_sync(orient_est): np.testing.assert_allclose(Rijs_sync, Rijs) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_global_J_sync_single_triplet(dtype): + """ + This exercises the J-synchronization algorithm using the smallest + possible problem size, a single triplets of relative rotations Rijs. + """ + # Generate 3 image source and orientation object. + src = Simulation(n=3, L=10, dtype=dtype, seed=SEED) + orient_est = CLSymmetryD2(src, n_theta=360, seed=SEED) + + # Grab set of rotations and generate a set of relative rotations, Rijs. + rots = orient_est.src.rotations + Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) + for p, (i, j) in enumerate(orient_est.pairs): + for k, g in enumerate(orient_est.gs): + k = (k + p) % 4 # Mix up the ordering of symmetric Rijs + Rijs[p, k] = rots[i].T @ (g * rots[j]) + + # J-conjugate a random Rij. + Rijs_conj = Rijs.copy() + inds = np.random.choice(orient_est.n_pairs, size=1, replace=False) + Rijs_conj[inds] = J_conjugate(Rijs[inds]) + + # Perform global J-synchronization and check that + # Rijs_sync is equal to either Rijs or J_conjugate(Rijs). + Rijs_sync = orient_est._global_J_sync(Rijs_conj) + need_to_conj_Rijs = not np.allclose(Rijs_sync[inds][0], Rijs[inds][0]) + if need_to_conj_Rijs: + np.testing.assert_allclose(Rijs_sync, J_conjugate(Rijs)) + else: + np.testing.assert_allclose(Rijs_sync, Rijs) + + +def test_sync_colors(orient_est): + # Grab set of rotations and generate a set of relative rotations, Rijs. + rots = orient_est.src.rotations + Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) + for p, (i, j) in enumerate(orient_est.pairs): + for k, g in enumerate(orient_est.gs): + k = (k + p) % 4 # Mix up the ordering of symmetric Rijs + Rijs[p, k] = rots[i].T @ (g * rots[j]) + + # Perform color synchronization. + colors, Rijs_rows = orient_est._sync_colors(Rijs) + + # Rijs_rows is shape (n_pairs, 3, 3, 3) where Rijs_rows[ij, m] corresponds + # to the outer product vij_m = rots[i, m].T @ rots[j, m] where m is the m'th row + # of the rotations matrices Ri and Rj. `colors` partitions the set of Rijs_rows + # such that the indices of `colors` corresponds to the row index m. + vijs = np.zeros((orient_est.n_pairs, 3, 3, 3), dtype=orient_est.dtype) + for p, (i, j) in enumerate(orient_est.pairs): + for m in range(3): + vijs[p, m] = np.outer(rots[i][m], rots[j][m]) + + # Reshape `colors` to shape (n_pairs, 3) and use to index Rijs_rows into the + # correctly order 3rd row outer products vijs. + colors = colors.reshape(orient_est.n_pairs, 3) + + # `colors` is an arbitrary permutation (but globally consistent), and we know + # that colors[0] should correspond to the ordering [0, 1, 2] due to the construction + # of Rijs[0] using the symmetric rotations g0, g1, g2, g3 in non-permuted order. + # So we sort the columns such that colors[0] = [0,1,2]. + + # Create a mapping array + perm = colors[0] + mapping = np.zeros_like(perm) + mapping[perm] = np.arange(3) + + # Apply this mapping to all rows of the colors array + colors_mapped = mapping[colors] + + # Synchronize Rijs_rows according to the color map. + row_indices = np.arange(orient_est.n_pairs)[:, None] + Rijs_rows_synced = Rijs_rows[row_indices, colors_mapped] + + # Rijs_rows_synced should match the ground truth vijs up to the sign of each row. + # So we multiply by the sign of the first column of the last two axes to sync signs. + vijs = vijs * np.sign(vijs[..., :, 0])[..., np.newaxis] + Rijs_rows_synced = ( + Rijs_rows_synced * np.sign(Rijs_rows_synced[..., :, 0])[..., np.newaxis] + ) + np.testing.assert_allclose( + vijs, Rijs_rows_synced, atol=utest_tolerance(orient_est.dtype) + ) + + #################### # Helper Functions # #################### From 27c6103b0ca9001fe8db34e4d0f4aa3914e382ee Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 17 May 2024 12:10:00 -0400 Subject: [PATCH 037/105] unused import --- tests/test_orient_d2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 079acc8a25..c640f7262a 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -5,7 +5,6 @@ from aspire.source import Simulation from aspire.utils import ( J_conjugate, - Rotation, all_pairs, mean_aligned_angular_distance, utest_tolerance, From ad7c328bcb3970717870eb0a96d96ee75b386e31 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 17 May 2024 15:57:43 -0400 Subject: [PATCH 038/105] test for sign sync. --- tests/test_orient_d2.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index c640f7262a..2270ee47a1 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -230,12 +230,13 @@ def test_sync_colors(orient_est): Rijs[p, k] = rots[i].T @ (g * rots[j]) # Perform color synchronization. - colors, Rijs_rows = orient_est._sync_colors(Rijs) - # Rijs_rows is shape (n_pairs, 3, 3, 3) where Rijs_rows[ij, m] corresponds # to the outer product vij_m = rots[i, m].T @ rots[j, m] where m is the m'th row # of the rotations matrices Ri and Rj. `colors` partitions the set of Rijs_rows # such that the indices of `colors` corresponds to the row index m. + colors, Rijs_rows = orient_est._sync_colors(Rijs) + + # Compute ground truth m'th row outer products. vijs = np.zeros((orient_est.n_pairs, 3, 3, 3), dtype=orient_est.dtype) for p, (i, j) in enumerate(orient_est.pairs): for m in range(3): @@ -273,6 +274,37 @@ def test_sync_colors(orient_est): ) +def test_sync_signs(orient_est): + """ + Sign synchronization consumes a set of m'th row outer products along with + a color synchronizing vector and returns a set of rotation matrices + that are the result of synchronizing the signs of the rows of the outer + products and factoring the outer products to form the rows of the rotations. + + In this test we provide a color-synchronized set of m'th row outer products + with a corresponding color vector and test that the output rotations + equivalent to the ground truth rotations up to a global alignment. + """ + rots = orient_est.src.rotations + + # Compute ground truth m'th row outer products. + vijs = np.zeros((orient_est.n_pairs, 3, 3, 3), dtype=orient_est.dtype) + for p, (i, j) in enumerate(orient_est.pairs): + for m in range(3): + vijs[p, m] = np.outer(rots[i][m], rots[j][m]) + + # We will pass in m'th row outer products that are color synchronized, + # ie. colors = [0, 1, 2, 0, 1, 2, ...] + perm = np.array([0, 2, 1]) + colors = np.tile(perm, orient_est.n_pairs) + + # Estimate rotations and check against ground truth. + rots_est = orient_est._sync_signs(vijs, colors) + mean_aligned_angular_distance( + rots, rots_est, degree_tol=utest_tolerance(orient_est.dtype) + ) + + #################### # Helper Functions # #################### From c1fdd6bf5cc41216a4209879de7fdf7643ed9902 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 20 May 2024 09:02:03 -0400 Subject: [PATCH 039/105] adjust angular distance tol. --- tests/test_orient_d2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 2270ee47a1..6f0b524ce5 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -300,9 +300,7 @@ def test_sync_signs(orient_est): # Estimate rotations and check against ground truth. rots_est = orient_est._sync_signs(vijs, colors) - mean_aligned_angular_distance( - rots, rots_est, degree_tol=utest_tolerance(orient_est.dtype) - ) + mean_aligned_angular_distance(rots, rots_est, degree_tol=1e-5) #################### From f178bd1123628df5ae5b1f750ed83950c9ca6d4e Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 20 May 2024 14:35:28 -0400 Subject: [PATCH 040/105] Test self-commonline score. --- src/aspire/abinitio/commonline_d2.py | 1 + tests/test_orient_d2.py | 117 ++++++++++++++++++++------- 2 files changed, 90 insertions(+), 28 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index eebcf0e0a7..f8e8ee57fc 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -669,6 +669,7 @@ def _compute_cl_scores(self): pbar.close() # Get estimated relative viewing directions. + self.corrs_idx = corrs_idx # Used in unit test self.Rijs_est = self._get_Rijs_from_lin_idx(corrs_idx) def _get_Rijs_from_lin_idx(self, lin_idx): diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 6f0b524ce5..64889db2e7 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -5,6 +5,7 @@ from aspire.source import Simulation from aspire.utils import ( J_conjugate, + Rotation, all_pairs, mean_aligned_angular_distance, utest_tolerance, @@ -27,22 +28,22 @@ SEED = 3 -@pytest.fixture(params=DTYPE, ids=lambda x: f"dtype={x}") +@pytest.fixture(params=DTYPE, ids=lambda x: f"dtype={x}", scope="module") def dtype(request): return request.param -@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}") +@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}", scope="module") def resolution(request): return request.param -@pytest.fixture(params=N_IMG, ids=lambda x: f"n images={x}") +@pytest.fixture(params=N_IMG, ids=lambda x: f"n images={x}", scope="module") def n_img(request): return request.param -@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}") +@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}", scope="module") def offsets(request): return request.param @@ -52,7 +53,7 @@ def offsets(request): ############ -@pytest.fixture +@pytest.fixture(scope="module") def source(n_img, resolution, dtype, offsets): vol = DnSymmetricVolume( L=resolution, order=2, C=1, K=100, dtype=dtype, seed=SEED @@ -70,29 +71,9 @@ def source(n_img, resolution, dtype, offsets): return src -@pytest.fixture +@pytest.fixture(scope="module") def orient_est(source): - # Search for common lines over less shifts for 0 offsets. - max_shift = 0 - shift_step = 1 - if source.offsets.all() != 0: - max_shift = 0.2 - shift_step = 0.1 # Reduce shift steps for non-integer offsets of Simulation. - - orient_est = CLSymmetryD2( - source, - max_shift=max_shift, - shift_step=shift_step, - n_theta=360, - n_rad=source.L, - grid_res=350, # Tuned for speed - inplane_res=15, # Tuned for speed - eq_min_dist=10, # Tuned for speed - epsilon=0.001, - seed=SEED, - ) - - return orient_est + return build_CL_from_source(source) ######### @@ -126,6 +107,63 @@ def test_estimate_rotations(orient_est): mean_aligned_angular_distance(rots_est, rots_gt_sync, degree_tol=deg_tol) +def test_scl_scores(orient_est): + + # Generate lookup data and extract rotations from the candidate `sphere_grid`. + orient_est._generate_lookup_data() + cand_rots = orient_est.inplane_rotated_grid1 + non_eq_idx = int( + np.argwhere(orient_est.eq_class1 == 0)[0] + ) # Take first non equator viewing direction + rots = cand_rots[ + non_eq_idx, :10 + ] # Take the first 10 inplane rots from non_eq viewing direction. + angles = Rotation(rots).angles + + # Create a Simulation using those first 10 candidate rotations. + vol = DnSymmetricVolume( + L=orient_est.src.L, order=2, C=1, K=100, dtype=orient_est.dtype, seed=SEED + ).generate() + + src = Simulation( + n=orient_est.src.n, + L=orient_est.src.L, + vols=vol, + angles=angles, + offsets=orient_est.src.offsets, + amplitudes=1, + seed=SEED, + ) + + # Initialize CL instance with new source. + CL = build_CL_from_source(src) + + # Generate lookup data and compute scl scores. + # Pre-compute phase-shifted polar Fourier. + CL._compute_shifted_pf() + + # Generate lookup data + CL._generate_lookup_data() + CL._generate_scl_lookup_data() + + # Compute self-commonline scores. + CL._compute_scl_scores() + + # CL.scls_scores is shape (n_img, n_cand_rots). Since we used the first + # 10 candidate rotations of the first non-equator viewing direction as our + # Simulation rotations, the maximum correlation for image i should occur at + # candidate rotation index (non_eq_idx * CL.n_inplane_rots + i). + max_corr_idx = np.argmax(CL.scls_scores, axis=1) + gt_idx = CL.n_inplane_rots * non_eq_idx + np.arange(10) + + # Check that self-commonline indices match ground truth. + n_match = np.sum(max_corr_idx == gt_idx) + match_tol = 0.99 # match at least 99%. + if not (src.offsets == 0.0).all(): + match_tol = 0.89 # match at least 89% with offsets. + np.testing.assert_array_less(match_tol, n_match / src.n) + + def test_global_J_sync(orient_est): """ For this test we build a set of relative rotations, Rijs, of shape @@ -295,7 +333,7 @@ def test_sync_signs(orient_est): # We will pass in m'th row outer products that are color synchronized, # ie. colors = [0, 1, 2, 0, 1, 2, ...] - perm = np.array([0, 2, 1]) + perm = np.array([0, 1, 2]) colors = np.tile(perm, orient_est.n_pairs) # Estimate rotations and check against ground truth. @@ -374,3 +412,26 @@ def g_sync_d2(rots, rots_gt): rots_gt_sync[i] = rots_symm[power_g_Ri] @ rot_gt return rots_gt_sync + + +def build_CL_from_source(source): + # Search for common lines over less shifts for 0 offsets. + max_shift = 0 + shift_step = 1 + if source.offsets.all() != 0: + max_shift = 0.2 + shift_step = 0.1 # Reduce shift steps for non-integer offsets of Simulation. + + orient_est = CLSymmetryD2( + source, + max_shift=max_shift, + shift_step=shift_step, + n_theta=360, + n_rad=source.L, + grid_res=350, # Tuned for speed + inplane_res=15, # Tuned for speed + eq_min_dist=10, # Tuned for speed + epsilon=0.001, + seed=SEED, + ) + return orient_est From b8f86233901521bdff524c4724746a712e056471 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 23 May 2024 13:33:33 -0400 Subject: [PATCH 041/105] Refactor to use n_theta other than 360. --- src/aspire/abinitio/commonline_d2.py | 108 +++++++++++++++------------ tests/test_orient_d2.py | 47 +++++------- 2 files changed, 79 insertions(+), 76 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index f8e8ee57fc..5e6b0c543d 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -297,7 +297,6 @@ def _generate_scl_angles(self, Ris, eq_idx, eq_class): :param eq_idx: Equator index mask for Ris. :param eq_class: Equator classification for Ris. """ - L = 360 # TODO: Maybe this should be self.n_theta # For each candidate rotation Ri we generate the set of 3 self-commonlines. scl_angles = np.zeros((*Ris.shape[:2], 3, 2), dtype=Ris.dtype) @@ -370,13 +369,15 @@ def _generate_scl_angles(self, Ris, eq_idx, eq_class): + scl_angles[eq_class > 0] * ~p[:, :, None, None] ) - # Convert angles from radians to degrees (indices). - scl_angles = np.round(scl_angles * 180 / np.pi) % L - - return scl_angles + # Convert from angles [0,2*pi) to degrees [0, 360). + return np.round(scl_angles * 180 / np.pi) % 360 def _generate_scl_indices(self, scl_angles, eq_class): - L = 360 + L = self.n_theta + + # Convert from angles to indices. + scl_indices, _ = self._generate_commonline_indices(scl_angles) + scl_angles = np.mod(np.round(scl_angles / (2 * np.pi) * L), L).astype(int) # Create candidate common line linear indices lists for equators. # As indicated above for equator candidate, for each self common line we @@ -399,10 +400,15 @@ def _generate_scl_indices(self, scl_angles, eq_class): idx1 = self._circ_seq(scl_angles[i, j, 0, 0], scl_angles[i, j, 1, 0], L) idx2 = self._circ_seq(scl_angles[i, j, 0, 1], scl_angles[i, j, 1, 1], L) + # Ensure idx1 and idx2 have same number of elements. + # Might be off by one due to n_theta discretization. + end = np.minimum(len(idx1), len(idx2)) + idx1, idx2 = idx1[:end], idx2[:end] + # Adjust so idx1 is in [0, 180) range. - geq_180 = idx1 >= 180 - idx1[geq_180] = (idx1[geq_180] - L // 2) % (L // 2) - idx2[geq_180] = (idx2[geq_180] + L // 2) % L + is_geq_than_pi = idx1 >= L // 2 + idx1[is_geq_than_pi] = (idx1[is_geq_than_pi] - L // 2) % (L // 2) + idx2[is_geq_than_pi] = (idx2[is_geq_than_pi] + L // 2) % L # register indices in list. eq_lin_idx_lists[0, count_eq, j] = np.ravel_multi_index( @@ -411,8 +417,6 @@ def _generate_scl_indices(self, scl_angles, eq_class): eq_lin_idx_lists[1, count_eq, j] = idx1 count_eq += 1 - scl_indices, _ = self._generate_commonline_indices(scl_angles) - return scl_indices, eq_lin_idx_lists def _generate_scl_scores_idx_map(self): @@ -543,10 +547,10 @@ def _compute_scl_scores(self): def _all_eq_measures(self, corrs): """ - Compute a measure of how much an image from data is close to be an equator. + Compute a measure of how much an image from data is close to an equator. """ # First compute the eq measure (corrs(scl-k,scl+k) for k=1:90) - # An eqautor image of a D2 molecule has the following property: If t_i is + # An equator image of a D2 molecule has the following property: If t_i is # the angle of one of the rays of the self common line then all the pairs of # rays of the form (t_i-k,t_i+k) for k=1:90 are identical. For each t_i we # average over correlations between the lines (t_i-k,t_i+k) for k=1:90 @@ -554,24 +558,28 @@ def _all_eq_measures(self, corrs): # with angle t_i is a self common line. # (This first loop can be done once outside this function and then pass # idx as an argument). - idx = np.zeros((180, 90, 2)) - idx_1 = np.mod(np.vstack((-np.arange(1, 91), np.arange(1, 91))), 360) + L = self.n_theta + idx = np.zeros((L // 2, L // 4, 2)) + idx_1 = np.mod( + np.vstack((-np.arange(1, L // 4 + 1), np.arange(1, L // 4 + 1))), L + ) idx[0, :, :] = idx_1.T - for k in range(1, 180): - idx[k, :, :] = np.mod(idx_1.T + k, 360) - idx = np.mod(idx, 360) + for k in range(1, L // 2): + idx[k, :, :] = np.mod(idx_1.T + k, L) + idx = np.mod(idx, L) + # Convert to Fourier ray indices. idx_1 = idx[:, :, 0].flatten() idx_2 = idx[:, :, 1].flatten() # Make all Ri coordinates < 180 and compute linear indices for corrrelations - bigger_than_180 = idx_1 >= 180 - idx_1[bigger_than_180] = idx_1[bigger_than_180] - 180 - idx_2[bigger_than_180] = (idx_2[bigger_than_180] + 180) % 360 + is_geq_than_pi = idx_1 >= L // 2 + idx_1[is_geq_than_pi] = idx_1[is_geq_than_pi] - (L // 2) + idx_2[is_geq_than_pi] = (idx_2[is_geq_than_pi] + (L // 2)) % L # Compute correlations. eq_corrs = corrs[idx_1.astype(int), idx_2.astype(int)] - eq_corrs = eq_corrs.reshape(180, 90) + eq_corrs = eq_corrs.reshape(L // 2, L // 4) corrs_mean = np.mean(eq_corrs, axis=1) # Now compute correlations for normals to scls. @@ -581,24 +589,24 @@ def _all_eq_measures(self, corrs): # between one Fourier ray of the normal to a self common line candidate t_i # with its anti-podal as an additional way to measure if the image is an # equator and t_i+0.5*pi is the normal to its self common line. - r = 2 + r = (2 * L) // 360 - normal_2_scl_idx = np.zeros((180, 2 * r + 1)) - normal_2_scl_idx_1 = np.mod(180 - np.arange(90 - r, 90 + r + 1), 360) + normal_2_scl_idx = np.zeros((L // 2, 2 * r + 1)) + normal_2_scl_idx_1 = np.mod(L // 2 - np.arange(L // 4 - r, L // 4 + r + 1), L) normal_2_scl_idx[0, :] = normal_2_scl_idx_1 - for k in range(1, 180): - normal_2_scl_idx[k, :] = np.mod(normal_2_scl_idx_1 + k, 360) + for k in range(1, L // 2): + normal_2_scl_idx[k, :] = np.mod(normal_2_scl_idx_1 + k, L) # Make all Ri coordinates <=180 and compute linear indices for corrrelations - bigger_than_180 = normal_2_scl_idx >= 180 - normal_2_scl_idx[bigger_than_180] = normal_2_scl_idx[bigger_than_180] - 180 + is_geq_than_pi = normal_2_scl_idx >= L // 2 + normal_2_scl_idx[is_geq_than_pi] = normal_2_scl_idx[is_geq_than_pi] - (L // 2) # Compute correlations for normals. normal_2_scl_idx = normal_2_scl_idx.flatten() normal_corrs = corrs[ - normal_2_scl_idx.astype(int), normal_2_scl_idx.astype(int) + 180 + normal_2_scl_idx.astype(int), normal_2_scl_idx.astype(int) + (L // 2) ] - normal_corrs = normal_corrs.reshape(180, 2 * r + 1) + normal_corrs = normal_corrs.reshape(L // 2, 2 * r + 1) normal_corrs_max = np.max(normal_corrs, axis=1) return corrs_mean * normal_corrs_max @@ -612,6 +620,8 @@ def _compute_cl_scores(self): Run common lines Maximum likelihood procedure for a D2 molecule, to find the set of rotations Ri^TgkRj, k=1,2,3,4 for each pair of images i and j. """ + L = self.n_theta + # Map the self common line scores of each 2 candidate rotations R_i,R_j to # the respective relative rotation candidate R_i^TR_j. n_lookup_1 = len(self.scl_idx_1) // 3 @@ -640,15 +650,11 @@ def _compute_cl_scores(self): # Compute maximum correlation over all shifts. corrs = 2 * np.real(pf_i @ np.conj(pf_j).T) - corrs = np.reshape( - corrs, (self.n_shifts, self.n_theta // 2, self.n_theta) - ) + corrs = np.reshape(corrs, (self.n_shifts, L // 2, L)) corrs = np.max(corrs, axis=0) # Take the product over symmetrically induced candidates. Eq. 4.5 in paper. - cl_idx = np.unravel_index( - self.cl_idx, (self.n_theta // 2, self.n_theta) - ) + cl_idx = np.unravel_index(self.cl_idx, (L // 2, L)) prod_corrs = corrs[cl_idx] prod_corrs = prod_corrs.reshape(len(prod_corrs) // 4, 4) @@ -1208,8 +1214,8 @@ def _unmix_colors(self, color_vecs): 2D rotation of these vectors (see Section 7.3 of D2 paper for details). """ n_p = color_vecs.shape[0] // 3 - d_theta = 360 // self.n_theta - max_t = 360 // d_theta + 1 + d_theta = self.n_theta // self.n_theta + max_t = self.n_theta // d_theta + 1 def R_theta(theta): R = np.array( @@ -1592,7 +1598,7 @@ def _circ_seq(n1, n2, L): if n2 < n1: n2 += L if n1 == n2: - return np.array(n1).astype(int) + return np.array([n1]).astype(int) % L seq = np.arange(n1, n2 + 1).astype(int) % L @@ -1806,25 +1812,29 @@ def _generate_commonline_angles( return cl_angles, eq2eq_Rij_table - @staticmethod - def _generate_commonline_indices(cl_angles): + def _generate_commonline_indices(self, cl_angles): + """ + Converts pairs pf commonline angles in [0, 360) first into polar Fourier + indices in [0, self.n_theta), then in commonline linear indices. + """ # TODO: This is not accounting for n_theta other than 360! + L = self.n_theta # Flatten the stack og_shape = cl_angles.shape cl_angles = np.reshape(cl_angles, (np.prod(og_shape[:-1]), 2)) # Fourier ray index - row_sub = np.round(cl_angles[:, 0]).astype("int") % 360 - col_sub = np.round(cl_angles[:, 1]).astype("int") % 360 + row_sub = np.round(cl_angles[:, 0] * L / 360).astype("int") % L + col_sub = np.round(cl_angles[:, 1] * L / 360).astype("int") % L # Restrict Ri in-plane coordinates to <180 degrees. - is_geq_than_pi = row_sub >= 180 - row_sub[is_geq_than_pi] = row_sub[is_geq_than_pi] - 180 - col_sub[is_geq_than_pi] = (col_sub[is_geq_than_pi] + 180) % 360 + is_geq_than_pi = row_sub >= L // 2 + row_sub[is_geq_than_pi] = row_sub[is_geq_than_pi] - L // 2 + col_sub[is_geq_than_pi] = (col_sub[is_geq_than_pi] + (L // 2)) % L - # Convert to linear indices in 360*180 correlation matrix (same as cls_lookup in matlab) - cl_idx = np.ravel_multi_index((row_sub, col_sub), dims=(180, 360)) + # Convert to linear indices in 180x360 correlation matrix. + cl_idx = np.ravel_multi_index((row_sub, col_sub), dims=(L // 2, L)) # Reshape cl_angles (to match matlab `cls`) cl_angles = cl_angles.reshape(og_shape) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 64889db2e7..1f757dee66 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -108,27 +108,25 @@ def test_estimate_rotations(orient_est): def test_scl_scores(orient_est): - + """ + This test uses a Simulation generated with rotations taken directly + from the D2 algorithm `sphere_grid` of candidate rotations. It is + these candidates which should produce maximum correlation scores since + they match perfectly the Simulation rotations. + """ # Generate lookup data and extract rotations from the candidate `sphere_grid`. + # In this case, we take first 10 candidates from a non-equator viewing direction. orient_est._generate_lookup_data() cand_rots = orient_est.inplane_rotated_grid1 - non_eq_idx = int( - np.argwhere(orient_est.eq_class1 == 0)[0] - ) # Take first non equator viewing direction - rots = cand_rots[ - non_eq_idx, :10 - ] # Take the first 10 inplane rots from non_eq viewing direction. + non_eq_idx = int(np.argwhere(orient_est.eq_class1 == 0)[0]) + rots = cand_rots[non_eq_idx, :10] angles = Rotation(rots).angles # Create a Simulation using those first 10 candidate rotations. - vol = DnSymmetricVolume( - L=orient_est.src.L, order=2, C=1, K=100, dtype=orient_est.dtype, seed=SEED - ).generate() - src = Simulation( n=orient_est.src.n, L=orient_est.src.L, - vols=vol, + vols=orient_est.src.vols, angles=angles, offsets=orient_est.src.offsets, amplitudes=1, @@ -138,11 +136,8 @@ def test_scl_scores(orient_est): # Initialize CL instance with new source. CL = build_CL_from_source(src) - # Generate lookup data and compute scl scores. - # Pre-compute phase-shifted polar Fourier. + # Generate lookup data. CL._compute_shifted_pf() - - # Generate lookup data CL._generate_lookup_data() CL._generate_scl_lookup_data() @@ -154,7 +149,7 @@ def test_scl_scores(orient_est): # Simulation rotations, the maximum correlation for image i should occur at # candidate rotation index (non_eq_idx * CL.n_inplane_rots + i). max_corr_idx = np.argmax(CL.scls_scores, axis=1) - gt_idx = CL.n_inplane_rots * non_eq_idx + np.arange(10) + gt_idx = non_eq_idx * CL.n_inplane_rots + np.arange(10) # Check that self-commonline indices match ground truth. n_match = np.sum(max_corr_idx == gt_idx) @@ -176,7 +171,7 @@ def test_global_J_sync(orient_est): Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) for p, (i, j) in enumerate(orient_est.pairs): for k, g in enumerate(orient_est.gs): - k = (k + p) % 4 # Mix up the ordering of symmetric Rijs + k = (k + p) % 4 # Mix up the ordering of Rijs Rijs[p, k] = rots[i].T @ (g * rots[j]) # J-conjugate a random set of Rijs. @@ -240,7 +235,7 @@ def test_global_J_sync_single_triplet(dtype): Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) for p, (i, j) in enumerate(orient_est.pairs): for k, g in enumerate(orient_est.gs): - k = (k + p) % 4 # Mix up the ordering of symmetric Rijs + k = (k + p) % 4 # Mix up the ordering of Rijs Rijs[p, k] = rots[i].T @ (g * rots[j]) # J-conjugate a random Rij. @@ -264,7 +259,7 @@ def test_sync_colors(orient_est): Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) for p, (i, j) in enumerate(orient_est.pairs): for k, g in enumerate(orient_est.gs): - k = (k + p) % 4 # Mix up the ordering of symmetric Rijs + k = (k + p) % 4 # Mix up the ordering of Rijs Rijs[p, k] = rots[i].T @ (g * rots[j]) # Perform color synchronization. @@ -303,10 +298,8 @@ def test_sync_colors(orient_est): # Rijs_rows_synced should match the ground truth vijs up to the sign of each row. # So we multiply by the sign of the first column of the last two axes to sync signs. - vijs = vijs * np.sign(vijs[..., :, 0])[..., np.newaxis] - Rijs_rows_synced = ( - Rijs_rows_synced * np.sign(Rijs_rows_synced[..., :, 0])[..., np.newaxis] - ) + vijs = vijs * np.sign(vijs[..., 0])[..., None] + Rijs_rows_synced = Rijs_rows_synced * np.sign(Rijs_rows_synced[..., 0])[..., None] np.testing.assert_allclose( vijs, Rijs_rows_synced, atol=utest_tolerance(orient_est.dtype) ) @@ -420,16 +413,16 @@ def build_CL_from_source(source): shift_step = 1 if source.offsets.all() != 0: max_shift = 0.2 - shift_step = 0.1 # Reduce shift steps for non-integer offsets of Simulation. + shift_step = 0.02 # Reduce shift steps for non-integer offsets of Simulation. orient_est = CLSymmetryD2( source, max_shift=max_shift, shift_step=shift_step, - n_theta=360, + n_theta=180, n_rad=source.L, grid_res=350, # Tuned for speed - inplane_res=15, # Tuned for speed + inplane_res=12, # Tuned for speed eq_min_dist=10, # Tuned for speed epsilon=0.001, seed=SEED, From 58644c7c8db79e6147a65b1b094edef8ddd77996 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 31 May 2024 15:55:03 -0400 Subject: [PATCH 042/105] add documentation. --- src/aspire/abinitio/commonline_d2.py | 236 +++++++++++++++------------ tests/test_orient_d2.py | 4 +- 2 files changed, 138 insertions(+), 102 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 5e6b0c543d..f3537e0edf 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -46,12 +46,15 @@ def __init__( :param max_shift: Maximum range for shifts as a proportion of resolution. Default = 0.15. :param shift_step: Resolution of shift estimation in pixels. Default = 1 pixel. :param grid_res: Number of sampling points on sphere for projetion directions. - These are generated using the Saaf - Kuijlaars algorithm. Default value is 1200. + These are generated using the Saaf-Kuijlaars algorithm. Default value is 1200. :param inplane_res: The sampling resolution of in-plane rotations for each - projetion direction. Default value is 5. + projetion direction. Default value is 5 degrees. :param eq_min_dist: Width of strip around equator projection directions from - which we DO NOT sample directions. Default value is 7. + which we DO NOT sample directions. Default value is 7 degrees. + :param epsilon: Tolerance for J-synchronization power method. :param seed: Optional seed for RNG. + :param mask: Option to mask `src.images` with a fuzzy mask (boolean). + Default, `True`, applies a mask. """ super().__init__( @@ -76,9 +79,8 @@ def __init__( def estimate_rotations(self): """ - Estimate rotation matrices for molecules with D2 symmetry. - - :return: Array of rotation matrices, size n_imgx3x3. + Estimate rotation matrices for molecules with D2 symmetry. Sets the attribute + self.rotations with an array of estimated rotation matrices, size src.nx3x3. """ # Pre-compute phase-shifted polar Fourier. self._compute_shifted_pf() @@ -221,6 +223,82 @@ def _generate_lookup_data(self): self.cl_idx_2, self.cl_angles2 = self._generate_commonline_indices(cl_angles2) self.cl_idx = np.hstack((self.cl_idx_1, self.cl_idx_2)) + def _generate_commonline_angles( + self, + Ris, + Rjs, + Ri_eq_idx, + Rj_eq_idx, + Ri_eq_class, + Rj_eq_class, + triu=True, + ): + """ + Compute commonline angles induced by the 4 sets of relative rotations + Rij = Ri.T @ g_m @ Rj, m = 0,1,2,3, where g_m is the identity and rotations + about the three axes of symmetry of a D2 symmetric molecule. + + :param Ris: First set of candidate rotations. + :param Rjs: Second set of candidate rotation. + :param Ri_eq_idx: Equator index mask. + :param Rj_eq_idx: Equator index mask. + :param Ri_eq_class: Equator classification for Ris. + :param Rj_eq_class: Equator classification for Rjs. + + :return: Commonline angles induced by relative rotation candidates. + """ + n_rots_i = len(Ris) + n_theta = Ris.shape[1] # Same for Rjs, TODO: Don't call this n_theta + + # Generate upper triangular table of indicators of all pairs which are not + # equators with respect to the same symmetry axis (named unique_pairs). + eq_table = np.outer(Ri_eq_idx, Rj_eq_idx) + in_same_class = (Ri_eq_class[:, None] - Rj_eq_class.T[None]) == 0 + eq2eq_Rij_table = ~(eq_table * in_same_class) + + # This is to match matlab code that uses triu with octant 1 table, but not + # with octants 1 and 2. + if triu: + eq2eq_Rij_table = np.triu(eq2eq_Rij_table, 1) + + n_pairs = np.sum(eq2eq_Rij_table) + idx = 0 + cl_angles = np.zeros((2 * n_pairs, n_theta, n_theta // 2, 4, 2)) + + for i in range(n_rots_i): + unique_pairs_i = np.where(eq2eq_Rij_table[i])[0] + if len(unique_pairs_i) == 0: + continue + Ri = Ris[i] + for j in unique_pairs_i: + Rj = Rjs[j, : n_theta // 2] + for k, g in enumerate(self.gs): + # Compute relative rotations candidates Rij = Ri.T @ gs @ Rj + g_Rj = g * Rj + Rijs = np.transpose(g_Rj, axes=(0, 2, 1)) @ Ri[:, None] + + # Common line indices induced by Rijs + cl_angles[idx, :, :, k, 0] = np.arctan2( + Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] + ) + cl_angles[idx, :, :, k, 1] = np.arctan2( + -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] + ) + cl_angles[idx + n_pairs, :, :, k, 0] = np.arctan2( + Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] + ) + cl_angles[idx + n_pairs, :, :, k, 1] = np.arctan2( + -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] + ) + + idx += 1 + + # Make all angles non-negative and convert to degrees. + cl_angles = (cl_angles + 2 * np.pi) % (2 * np.pi) + cl_angles = cl_angles * 180 / np.pi + + return cl_angles, eq2eq_Rij_table + ######################################## # Generate Self-Commonline Lookup Data # ######################################## @@ -291,11 +369,14 @@ def _generate_scl_lookup_data(self): def _generate_scl_angles(self, Ris, eq_idx, eq_class): """ - Generate self-commonline angles. + Generate self-commonline angles. For each candidate rotation a pair of self-commonline + angles are generated for each of the 3 self-commonlines induced by D2 symmetry. :param Ris: Candidate rotation matrices, (n_sphere_grid, n_inplane_rots, 3, 3). :param eq_idx: Equator index mask for Ris. :param eq_class: Equator classification for Ris. + + :return: `scl_angles` of shape (n_sphere_grid, n_inplane_rots, 3, 2). """ # For each candidate rotation Ri we generate the set of 3 self-commonlines. @@ -351,8 +432,7 @@ def _generate_scl_angles(self, Ris, eq_idx, eq_class): # indices 1 and 2, but flip one common line to antipodal. scl_angles[eq_class == 3, :, 0] = scl_angles[eq_class == 3][:, :, 0, [1, 0]] - # TODO: Maybe a cleaner way to do this. - # Make sure angle range is <= 180 degrees. + # Make sure angle range is < 180 degrees. # p1 marks "equator" self-commonlines where both entries of the first # scl are greater than both entries of the second scl. p1 = scl_angles[eq_class > 0, :, 0] > scl_angles[eq_class > 0, :, 1] @@ -373,6 +453,20 @@ def _generate_scl_angles(self, Ris, eq_idx, eq_class): return np.round(scl_angles * 180 / np.pi) % 360 def _generate_scl_indices(self, scl_angles, eq_class): + """ + Generate self-commonline indices. This includes a set of linear indices for + all candidate rotations as well as lists of self-commonline index ranges for + equator candidates. + + :param scl_angles: Self-commonline angles, shape (n_sphere_grid, n_inplane_rots, 3, 2). + :param eq_class: Equator classification for the sphere_grid points represented + by the first axis of `scl_angles`. + + :returns: + - scl_indices, self-commonline linear indices. + - eq_lin_idx_lists, a list containing a range of self-commonline + indices for each equator candidate. + """ L = self.n_theta # Convert from angles to indices. @@ -388,9 +482,9 @@ def _generate_scl_indices(self, scl_angles, eq_class): n_inplane_rots = scl_angles.shape[1] count_eq = 0 - # eq_lin_idx_lists[1,i,j] registers a list of linear indices of the j'th + # eq_lin_idx_lists[0,i,j] registers a list of linear indices of the j'th # in-plane rotation of the range for the (only) self common line of the i'th - # candidate. eq_lin_idx_lists[2,i,j] registers the actual (integer) angle + # candidate. eq_lin_idx_lists[1,i,j] registers the actual (integer) angle # of the self common line in the 2D Fourier space. Note that we need only # one number since each self common line has radial coordinates of the form # (theta, theta+180). @@ -420,6 +514,10 @@ def _generate_scl_indices(self, scl_angles, eq_class): return scl_indices, eq_lin_idx_lists def _generate_scl_scores_idx_map(self): + """ + Generates lookup tables for maximum likelihood scheme to estimate commonlines + between images. + """ n_rot_1 = len(self.scl_idx_1) // (3 * self.n_inplane_rots) n_rot_2 = len(self.scl_idx_2) // (3 * self.n_inplane_rots) @@ -446,7 +544,7 @@ def _generate_scl_scores_idx_map(self): ) idx += 1 - # First the map for i, where the rows and columns of S are indexed by # double indexes (i,j), 1<=i Date: Mon, 3 Jun 2024 15:46:17 -0400 Subject: [PATCH 043/105] Self-review: Use DnSymmetryGroup for gs, Remove matlab pf convention, Update docstrings, other cleanup. --- src/aspire/abinitio/commonline_d2.py | 76 ++++++++++++---------------- tests/test_orient_d2.py | 15 +++--- 2 files changed, 37 insertions(+), 54 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index f3537e0edf..30b6e865ab 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -8,6 +8,7 @@ from aspire.operators import PolarFT from aspire.utils import J_conjugate, Rotation, all_pairs, all_triplets, tqdm, trange from aspire.utils.random import randn +from aspire.volume import DnSymmetryGroup logger = logging.getLogger(__name__) @@ -72,11 +73,15 @@ def __init__( self.eq_min_dist = eq_min_dist self.seed = seed self.epsilon = epsilon - self._generate_gs() + self.triplets = all_triplets(self.n_img) self.pairs, self.pairs_to_linear = all_pairs(self.n_img, return_map=True) self.n_pairs = len(self.pairs) + # D2 symmetry group. + # Rearrange in order Identity, about_x, about_y, about_z. + self.gs = DnSymmetryGroup(order=2, dtype=self.dtype).matrices[[0, 3, 2, 1]] + def estimate_rotations(self): """ Estimate rotation matrices for molecules with D2 symmetry. Sets the attribute @@ -128,8 +133,6 @@ def _compute_shifted_pf(self): # Reconstruct full polar Fourier for use in correlation. pf[:, :, 0] = 0 # Matching matlab convention to zero out the lowest frequency. pf /= norm(pf, axis=2)[..., np.newaxis] # Normalize each ray. - pf *= np.sqrt(2) / 2 # Magic number to match matlab pf. Remove after debug. - pf = pf[:, :, ::-1] # also to match matlab. Can remove. self.pf_full = PolarFT.half_to_full(pf) # Pre-compute shifted pf's. @@ -163,7 +166,7 @@ def _generate_lookup_data(self): # some symmetry axis only have 2 common lines instead of 4, and must be # treated separately. # We detect such directions by taking a strip of radius - # eq_filter_angle about the 3 great circles perpendicular to the symmetry + # `eq_min_dist` about the 3 great circles perpendicular to the symmetry # axes of D2 (i.e to X,Y and Z axes). eq_idx1, eq_class1 = self._mark_equators(sphere_grid1, self.eq_min_dist) eq_idx2, eq_class2 = self._mark_equators(sphere_grid2, self.eq_min_dist) @@ -215,7 +218,7 @@ def _generate_lookup_data(self): self.eq_idx2, self.eq_class1, self.eq_class2, - triu=False, + same_octant=False, ) # Generate commonline indices. @@ -231,7 +234,7 @@ def _generate_commonline_angles( Rj_eq_idx, Ri_eq_class, Rj_eq_class, - triu=True, + same_octant=True, ): """ Compute commonline angles induced by the 4 sets of relative rotations @@ -244,6 +247,7 @@ def _generate_commonline_angles( :param Rj_eq_idx: Equator index mask. :param Ri_eq_class: Equator classification for Ris. :param Rj_eq_class: Equator classification for Rjs. + :param same_octant: True if both sets of candidates are in the same octant. :return: Commonline angles induced by relative rotation candidates. """ @@ -256,9 +260,8 @@ def _generate_commonline_angles( in_same_class = (Ri_eq_class[:, None] - Rj_eq_class.T[None]) == 0 eq2eq_Rij_table = ~(eq_table * in_same_class) - # This is to match matlab code that uses triu with octant 1 table, but not - # with octants 1 and 2. - if triu: + # For candidates in the same octant only need upper triangle of table. + if same_octant: eq2eq_Rij_table = np.triu(eq2eq_Rij_table, 1) n_pairs = np.sum(eq2eq_Rij_table) @@ -274,7 +277,7 @@ def _generate_commonline_angles( Rj = Rjs[j, : n_theta // 2] for k, g in enumerate(self.gs): # Compute relative rotations candidates Rij = Ri.T @ gs @ Rj - g_Rj = g * Rj + g_Rj = g @ Rj Rijs = np.transpose(g_Rj, axes=(0, 2, 1)) @ Ri[:, None] # Common line indices induced by Rijs @@ -339,7 +342,7 @@ def _generate_scl_lookup_data(self): # non equators then we have the sub list is # C_(i_1)1,...,C(i_1)r,...C_(i_p)1,...,C_(i_p)r. n_non_eq = np.sum(self.eq_class1 == 0) + np.sum(self.eq_class2 == 0) - non_eq_idx = np.zeros((n_non_eq, int(self.n_inplane_rots))) + non_eq_idx = np.zeros((n_non_eq, self.n_inplane_rots), dtype=int) non_eq_idx[:, 0] = ( np.hstack( ( @@ -352,18 +355,16 @@ def _generate_scl_lookup_data(self): for i in range(1, self.n_inplane_rots): non_eq_idx[:, i] = non_eq_idx[:, 0] + i - self.non_eq_idx = non_eq_idx.astype(int) + self.non_eq_idx = non_eq_idx # Non-topview equator indices. - non_tv_eq_idx = np.concatenate( + self.non_tv_eq_idx = np.concatenate( ( np.where(self.eq_class1 > 0)[0], len(self.eq_class1) + np.where(self.eq_class2 > 0)[0], ) ) - self.non_tv_eq_idx = non_tv_eq_idx.astype(int) - # Generate maps from scl indices to relative rotations. self._generate_scl_scores_idx_map() @@ -384,9 +385,8 @@ def _generate_scl_angles(self, Ris, eq_idx, eq_class): n_rots = len(Ris) for i in range(n_rots): Ri = Ris[i] - # TODO: Reversing self.gs here to match matlab. Should use as is. - for k, g in enumerate(self.gs[::-1][:3]): - g_Ri = g * Ri + for k, g in enumerate(self.gs[1:]): + g_Ri = g @ Ri Riis = np.transpose(Ri, axes=(0, 2, 1)) @ g_Ri scl_angles[i, :, k, 0] = np.arctan2(Riis[:, 2, 0], -Riis[:, 2, 1]) @@ -398,7 +398,7 @@ def _generate_scl_angles(self, Ris, eq_idx, eq_class): # Deal with non top view equators # A non-TV equator has only one self common line. However, we clasify an # equator as an image whose projection direction is at radial distance < - # eq_filter_angle from the great circle perpendicural to a symmetry axis, + # `eq_min_dist` from the great circle perpendicural to a symmetry axis, # and not strictly zero distance. Thus in most cases we get 2 common lines # differing by a small difference in degrees. Actually the calculation above # gives us two NEARLY antipodal lines, so we first flip one of them by @@ -608,7 +608,7 @@ def _compute_scl_scores(self): pf_i_shifted = self.pf_shifted[i] # Compute max correlation over all shifts. - corrs = 2 * np.real(pf_i_shifted @ np.conj(pf_full_i).T) + corrs = np.real(pf_i_shifted @ np.conj(pf_full_i).T) corrs = np.reshape(corrs, (self.n_shifts, n_theta // 2, n_theta)) corrs = np.max(corrs, axis=0) @@ -691,7 +691,7 @@ def _all_eq_measures(self, corrs): # between one Fourier ray of the normal to a self common line candidate t_i # with its anti-podal as an additional way to measure if the image is an # equator and t_i+0.5*pi is the normal to its self common line. - r = (2 * L) // 360 + r = 2 # Search radius within 2 adjacent rays of normal ray. normal_2_scl_idx = np.zeros((L // 2, 2 * r + 1)) normal_2_scl_idx_1 = np.mod(L // 2 - np.arange(L // 4 - r, L // 4 + r + 1), L) @@ -751,7 +751,7 @@ def _compute_cl_scores(self): pf_j = self.pf_full[j] # Compute maximum correlation over all shifts. - corrs = 2 * np.real(pf_i @ np.conj(pf_j).T) + corrs = np.real(pf_i @ np.conj(pf_j).T) corrs = np.reshape(corrs, (self.n_shifts, L // 2, L)) corrs = np.max(corrs, axis=0) @@ -855,7 +855,7 @@ def _get_Rijs_from_oct(self, lin_idx, octant=1): Rjs = inplane_rotated_grid2[Rjs_lin_idx] for k, g in enumerate(self.gs): - Rijs_est[:, k] = Ris_t @ (g * Rjs) + Rijs_est[:, k] = Ris_t @ g @ Rjs Rijs_est[transpose_idx] = np.transpose(Rijs_est[transpose_idx], (0, 1, 3, 2)) @@ -867,13 +867,12 @@ def _get_Rijs_from_oct(self, lin_idx, octant=1): def _global_J_sync(self, Rijs): """ - Global J-synchronization of all third row outer products. Given 3x3 matrices Rijs and viis, each - of which might contain a spurious J (ie. vij = J*vi*vj^T*J instead of vij = vi*vj^T), - we return Rijs and viis that all have either a spurious J or not. + Global J-synchronization of all third row outer products. Given n_pairsx4x3x3 matrices Rijs, each + of which might contain a spurious J (ie. Rij = J @ Ri.T @ gs @ Rj @ J instead of Rij = Ri.T @ gs @ Rj), + we return Rijs that all have either a spurious J or not. :param Rijs: An (n-choose-2)x4 x3x3 array where each 3x3 slice holds an estimate for the corresponding - outer-product vi*vj^T between the third rows of the rotation matrices Ri and Rj. Each estimate - might have a spurious J independently of other estimates. + outer-product Ri.T @ Rj. Each estimate might have a spurious J independently of other estimates. :return: Rijs, all of which have a spurious J or not. """ @@ -1008,16 +1007,16 @@ def _signs_times_v2(self, J_list, vec): The J-synchronization matrix is a matrix representation of the handedness graph, Gamma, whose set of nodes consists of the estimates Rijs and whose set of edges consists of the undirected edges between - all triplets of estimates vij, vjk, and vik, where i Date: Wed, 5 Jun 2024 14:55:21 -0400 Subject: [PATCH 044/105] Add logging to main functions. --- src/aspire/abinitio/commonline_d2.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 30b6e865ab..167ec07912 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -120,6 +120,7 @@ def _compute_shifted_pf(self): """ Pre-compute shifted and full polar Fourier transforms. """ + logger.info("Preparing polar Fourier transform.") pf = self.pf # Generate shift phases. @@ -149,6 +150,7 @@ def _generate_lookup_data(self): """ Generate candidate relative rotations and corresponding common line indices. """ + logger.info("Generating commonline lookup data.") # Generate uniform grid on sphere with Saff-Kuijlaars and take one quarter # of sphere because of D2 symmetry redundancy. sphere_grid = self._saff_kuijlaars(self.grid_res) @@ -310,6 +312,7 @@ def _generate_scl_lookup_data(self): """ Generate lookup data for self-commonlines. """ + logger.info("Generating self-commonline lookup data.") # Get self-commonline angles. self.scl_angles1 = self._generate_scl_angles( self.inplane_rotated_grid1, @@ -585,6 +588,7 @@ def _compute_scl_scores(self): """ Compute correlations for self-commonline candidates. """ + logger.info("Computing self-commonline correlation scores.") n_img = self.n_img n_theta = self.n_theta n_eq = len(self.non_tv_eq_idx) @@ -722,6 +726,7 @@ def _compute_cl_scores(self): Run common lines Maximum likelihood procedure for a D2 molecule, to find the set of rotations Ri^TgkRj, k=1,2,3,4 for each pair of images i and j. """ + logger.info("Computing commonline correlation scores.") L = self.n_theta # Map the self common line scores of each 2 candidate rotations R_i,R_j to @@ -876,6 +881,7 @@ def _global_J_sync(self, Rijs): :return: Rijs, all of which have a spurious J or not. """ + logger.info("Performing global handedness synchronization.") # Find best J_configuration. J_list = self._J_configuration(Rijs) @@ -1068,6 +1074,7 @@ def _sync_colors(self, Rijs): The color sync procedure partitions the set of 3-tuples of m'th row outer products into 3 sets of row-consistent outer products up to the sign of each. """ + logger.info("Performing rotations' rows synchronization.") # Generate array of one rank matrices from which we can extract rows. # Matrices are of the form 0.5(Ri^TRj+Ri^TgkRj). Each such matrix can be # written in the form Qi^T*Ik*Qj where Ik is a 3x3 matrix with all zero @@ -1404,6 +1411,7 @@ def _sync_signs(self, rr, c_vec): synchroniztion. At the end all rows of the rotations Ri are exctracted and the matrices Ri are assembled. """ + logger.info("Performing signs synchronization.") # Partition the union of tuples {0.5*(Ri^TRj+Ri^TgkRj), k=1:3} according # to the color partition established in color synchronization procedure. # The partition is stored in two different arrays each with the purpose @@ -1582,7 +1590,7 @@ def _sync_signs(self, rr, c_vec): signs = np.zeros((3, self.n_pairs), dtype=self.dtype) s_out = np.zeros((3, 3), dtype=self.dtype) - logger.info("Constructing and decomposing 3 sign synchroniztion matrices...") + logger.info("Constructing and decomposing 3 sign synchroniztion matrices.") # The matrix S requires space on order of O(N^4). Instead of storing it # in memory we compute its SVD using the function smat which multiplies # (N over 2)x1 vectors by S. @@ -1618,7 +1626,7 @@ def _sync_signs(self, rr, c_vec): # Adjust the signs of Qij^c in the matrices cMat(:,:,c) for all c=1,2,3 # and 1<=i Date: Mon, 10 Jun 2024 08:46:24 -0400 Subject: [PATCH 045/105] remove unnecesary mod --- src/aspire/abinitio/commonline_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 167ec07912..fced6a30fa 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -504,7 +504,7 @@ def _generate_scl_indices(self, scl_angles, eq_class): # Adjust so idx1 is in [0, 180) range. is_geq_than_pi = idx1 >= L // 2 - idx1[is_geq_than_pi] = (idx1[is_geq_than_pi] - L // 2) % (L // 2) + idx1[is_geq_than_pi] = (idx1[is_geq_than_pi] - L // 2) idx2[is_geq_than_pi] = (idx2[is_geq_than_pi] + L // 2) % L # register indices in list. From 9ed4317004b7a46690d8baa1d22c39e887695ad2 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 10 Jun 2024 08:54:17 -0400 Subject: [PATCH 046/105] Revert eigs to use largest real. black. --- src/aspire/abinitio/commonline_d2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index fced6a30fa..dd27d5ff61 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -504,7 +504,7 @@ def _generate_scl_indices(self, scl_angles, eq_class): # Adjust so idx1 is in [0, 180) range. is_geq_than_pi = idx1 >= L // 2 - idx1[is_geq_than_pi] = (idx1[is_geq_than_pi] - L // 2) + idx1[is_geq_than_pi] = idx1[is_geq_than_pi] - L // 2 idx2[is_geq_than_pi] = (idx2[is_geq_than_pi] + L // 2) % L # register indices in list. @@ -1106,7 +1106,7 @@ def _sync_colors(self, Rijs): 3 * n_pairs, seed=self.seed ) # Seed eigs initial vector for iterative method v0 = v0 / norm(v0) - vals, colors = la.eigs(color_mat, k=3, which="LM", v0=v0) # Changed from "LR" + vals, colors = la.eigs(color_mat, k=3, which="LR", v0=v0) vals = np.real(vals) colors = np.real(colors) colors = np.sign(colors[0]) * colors # Stable eigs From c755998d952a98147c9d3fb2af4cdf8d888aaaf8 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 10 Jun 2024 15:44:19 -0400 Subject: [PATCH 047/105] Add dtype pass-through checks to tests. --- tests/test_orient_d2.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 6c50df17e9..6e8f2741ee 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -103,6 +103,9 @@ def test_estimate_rotations(orient_est): # distance between them is less than 5 degrees. mean_aligned_angular_distance(rots_est, rots_gt_sync, degree_tol=5) + # Check dtype pass-through. + assert rots_est.dtype == orient_est.dtype + def test_scl_scores(orient_est): """ @@ -155,6 +158,9 @@ def test_scl_scores(orient_est): match_tol = 0.89 # match at least 89% with offsets. np.testing.assert_array_less(match_tol, n_match / src.n) + # Check dtype pass-through. + assert CL.scls_scores.dtype == orient_est.dtype + def test_global_J_sync(orient_est): """ @@ -216,6 +222,9 @@ def test_global_J_sync(orient_est): else: np.testing.assert_allclose(Rijs_sync, Rijs) + # Check dtype pass-through. + assert Rijs_sync.dtype == orient_est.dtype + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_global_J_sync_single_triplet(dtype): @@ -301,6 +310,9 @@ def test_sync_colors(orient_est): vijs, Rijs_rows_synced, atol=utest_tolerance(orient_est.dtype) ) + # Check dtype pass-through. + assert Rijs_rows.dtype == orient_est.dtype + def test_sync_signs(orient_est): """ @@ -330,6 +342,9 @@ def test_sync_signs(orient_est): rots_est = orient_est._sync_signs(vijs, colors) mean_aligned_angular_distance(rots, rots_est, degree_tol=1e-5) + # Check dtype pass-through. + assert rots_est.dtype == orient_est.dtype + #################### # Helper Functions # From ef7bbf967fec145cc2eab27b12da5c001db66b8a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 17 Jun 2024 08:55:21 -0400 Subject: [PATCH 048/105] remove einsum for Garrett --- src/aspire/abinitio/commonline_d2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index dd27d5ff61..5df3e24f83 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -945,7 +945,9 @@ def _compare_rots(self, Rij, Rjk_t, Rik): corresponding to best configuration for the provided triplet of relative rotations. """ - prod_arr = np.einsum("nij,mjk->nmik", Rik, Rjk_t) + # We compute the four sets of 4^3 norms |Rik @ Rjk.T - Rij| + # See equation (6.11) in publication. + prod_arr = Rik[:, None, :, :] @ Rjk_t[None, :, :, :] arr = np.zeros((8, 8, 3, 3), dtype=self.dtype) arr[0:4, 0:4] = prod_arr - Rij[0] From 9c6b4d198116872cc3555b6045e222a8bde79804 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 17 Jun 2024 09:41:26 -0400 Subject: [PATCH 049/105] Remove more einsums. --- src/aspire/abinitio/commonline_d2.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 5df3e24f83..69798fd145 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -18,7 +18,7 @@ class CLSymmetryD2(CLOrient3D): Define a class to estimate 3D orientations using common lines methods for molecules with D2 (dihedral) symmetry. - The related publications are: + Corresponding publication: E. Rosen and Y. Shkolnisky, Common lines ab-initio reconstruction of D2-symmetric molecules, SIAM Journal on Imaging Sciences, volume 13-4, p. 1898-1994, 2020 @@ -94,7 +94,7 @@ def estimate_rotations(self): self._generate_lookup_data() self._generate_scl_lookup_data() - # Compute common-line scores. + # Compute self common-line scores. self._compute_scl_scores() # Compute common-lines and estimate relative rotations Rijs. @@ -947,7 +947,7 @@ def _compare_rots(self, Rij, Rjk_t, Rik): """ # We compute the four sets of 4^3 norms |Rik @ Rjk.T - Rij| # See equation (6.11) in publication. - prod_arr = Rik[:, None, :, :] @ Rjk_t[None, :, :, :] + prod_arr = Rik[:, None] @ Rjk_t[None] arr = np.zeros((8, 8, 3, 3), dtype=self.dtype) arr[0:4, 0:4] = prod_arr - Rij[0] @@ -1147,11 +1147,9 @@ def _match_colors(self, Rijs_rows): ik = self.pairs_to_linear[i, k] # For r=1:3 compute 3*3 products v_{ji}(r)v_{ik}v_{kj} - prod_arr = np.einsum("nij,mjk->mnik", Rijs_rows[ik], Rijs_rows_t[jk]) + prod_arr = Rijs_rows[ik, None] @ Rijs_rows_t[jk, :, None] prod_arr_tmp = prod_arr.copy() - prod_arr = np.einsum( - "nij,mjk->nmik", Rijs_rows_t[ij], prod_arr.reshape((9, 3, 3)) - ) + prod_arr = Rijs_rows_t[ij, :, None] @ prod_arr_tmp.reshape((9, 3, 3))[None] prod_arr = np.transpose( prod_arr.reshape((3, 3, 3, 9), order="F"), (2, 1, 0, 3) ) @@ -1191,13 +1189,10 @@ def _match_colors(self, Rijs_rows): # For r=1:3 compute 3*3 products v_{ij}(r)v_{jk}v_{ki} and compare to # Compare to v_{ii}(r)=v_{ij}v_{ji} prod_arr = np.transpose(prod_arr_tmp, (0, 1, 3, 2)) - prod_arr = np.einsum( - "nij,mjk->mnik", Rijs_rows[ij], prod_arr.reshape(9, 3, 3) - ) + prod_arr = Rijs_rows[ij, :, None] @ prod_arr.reshape((9, 3, 3))[None] prod_arr = np.transpose( prod_arr.reshape((3, 3, 3, 9), order="F"), (1, 0, 2, 3) ) - # Commented out calculations in matlab here. # Compare to v_{ii}(r)=v_{ik}v_{ki}. self_prods = Rijs_rows[ik] @ Rijs_rows_t[ik] From f011ebb93346a6cc88a69cf8e68099801be36cca Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 17 Jun 2024 10:35:01 -0400 Subject: [PATCH 050/105] line wrap docstrings --- src/aspire/abinitio/commonline_d2.py | 47 +++++++++++++++------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 69798fd145..8ccc0fe81c 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -872,12 +872,14 @@ def _get_Rijs_from_oct(self, lin_idx, octant=1): def _global_J_sync(self, Rijs): """ - Global J-synchronization of all third row outer products. Given n_pairsx4x3x3 matrices Rijs, each - of which might contain a spurious J (ie. Rij = J @ Ri.T @ gs @ Rj @ J instead of Rij = Ri.T @ gs @ Rj), - we return Rijs that all have either a spurious J or not. + Global J-synchronization of all third row outer products. Given n_pairsx4x3x3 + matrices Rijs, each of which might contain a spurious J, ie. + Rij = J @ Ri.T @ gs @ Rj @ J instead of Rij = Ri.T @ gs @ Rj, we return Rijs + that all have either a spurious J or not. - :param Rijs: An (n-choose-2)x4 x3x3 array where each 3x3 slice holds an estimate for the corresponding - outer-product Ri.T @ Rj. Each estimate might have a spurious J independently of other estimates. + :param Rijs: An (n-choose-2)x4 x3x3 array where each 3x3 slice holds an estimate + for the corresponding outer-product Ri.T @ Rj. Each estimate might have a + spurious J independently of other estimates. :return: Rijs, all of which have a spurious J or not. """ @@ -1013,27 +1015,30 @@ def _signs_times_v2(self, J_list, vec): """ Multiplication of the J-synchronization matrix by a candidate eigenvector. - The J-synchronization matrix is a matrix representation of the handedness graph, Gamma, whose set of - nodes consists of the estimates Rijs and whose set of edges consists of the undirected edges between - all triplets of estimates Rij, Rjk, and Rik, where i Date: Mon, 17 Jun 2024 14:22:57 -0400 Subject: [PATCH 051/105] remove debug comment. --- src/aspire/abinitio/commonline_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 8ccc0fe81c..b41bf2a1c6 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1400,7 +1400,7 @@ def R_theta(theta): else: colors[i] = p_i_sqr colors = colors.flatten() - # colors = 2 - colors # For debug. remove + return colors, best_unmix ##################### From 6cf681d212b3bcf68f39cf722c341418f682b326 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 22 Jul 2024 11:03:34 -0400 Subject: [PATCH 052/105] remove CAPS --- src/aspire/abinitio/commonline_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index b41bf2a1c6..e1791ba847 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -51,7 +51,7 @@ def __init__( :param inplane_res: The sampling resolution of in-plane rotations for each projetion direction. Default value is 5 degrees. :param eq_min_dist: Width of strip around equator projection directions from - which we DO NOT sample directions. Default value is 7 degrees. + which we do not sample directions. Default value is 7 degrees. :param epsilon: Tolerance for J-synchronization power method. :param seed: Optional seed for RNG. :param mask: Option to mask `src.images` with a fuzzy mask (boolean). From 05fea9d428b00ba9d5cf5505185711d858e21060 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 22 Jul 2024 11:11:54 -0400 Subject: [PATCH 053/105] reshape pf.shifted --- src/aspire/abinitio/commonline_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index e1791ba847..b3b45500b0 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -137,7 +137,7 @@ def _compute_shifted_pf(self): self.pf_full = PolarFT.half_to_full(pf) # Pre-compute shifted pf's. - pf_shifted = (pf * shift_phases[:, None, None]).swapaxes(0, 1) + pf_shifted = pf[:, None] * shift_phases[None, :, None] self.pf_shifted = pf_shifted.reshape( (self.n_img, self.n_shifts * (self.n_theta // 2), r_max) ) From 3a38b88bb7e1a9d140dedc8dda36df37b59ce8bc Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 22 Jul 2024 11:36:01 -0400 Subject: [PATCH 054/105] use count_nonzero instead of sum --- src/aspire/abinitio/commonline_d2.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index b3b45500b0..dcc6536972 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -266,7 +266,7 @@ def _generate_commonline_angles( if same_octant: eq2eq_Rij_table = np.triu(eq2eq_Rij_table, 1) - n_pairs = np.sum(eq2eq_Rij_table) + n_pairs = np.count_nonzero(eq2eq_Rij_table) idx = 0 cl_angles = np.zeros((2 * n_pairs, n_theta, n_theta // 2, 4, 2)) @@ -344,7 +344,9 @@ def _generate_scl_lookup_data(self): # create a sub-list of only non equator candidates, i.e., if i_1,...,i_p are # non equators then we have the sub list is # C_(i_1)1,...,C(i_1)r,...C_(i_p)1,...,C_(i_p)r. - n_non_eq = np.sum(self.eq_class1 == 0) + np.sum(self.eq_class2 == 0) + n_non_eq = np.count_nonzero(self.eq_class1 == 0) + np.count_nonzero( + self.eq_class2 == 0 + ) non_eq_idx = np.zeros((n_non_eq, self.n_inplane_rots), dtype=int) non_eq_idx[:, 0] = ( np.hstack( @@ -525,7 +527,7 @@ def _generate_scl_scores_idx_map(self): n_rot_2 = len(self.scl_idx_2) // (3 * self.n_inplane_rots) # First the map for i 0: Rijs_est[oct1_idx] = self._get_Rijs_from_oct(lin_idx[oct1_idx], octant=1) if n_est_in_oct1 <= len(lin_idx): @@ -818,7 +820,7 @@ def _get_Rijs_from_oct(self, lin_idx, octant=1): unique_pairs = self.eq2eq_Rij_table_12 n_theta = self.n_inplane_rots - n_lookup_pairs = np.sum(unique_pairs, dtype=np.int64) + n_lookup_pairs = np.count_nonzero(unique_pairs) n_rots = len(self.sphere_grid1) if octant == 1: n_rots2 = n_rots @@ -1687,7 +1689,7 @@ def _calc_Rij_prods(self, c_mat_5d, i, j, c): # In case we get a zero score arbitrarily choose sign +1. ij_signs = np.sum(Rij, axis=(-2, -1)) zeros_idx = ij_signs == 0 - if np.sum(zeros_idx) > 0: + if np.count_nonzero(zeros_idx) > 0: ij_signs[zeros_idx] = 1 return np.sign(ij_signs) @@ -1793,7 +1795,7 @@ def _mark_equators(sphere_grid, eq_filter_angle): # Mark all views close to an equator. eq_min_dist = np.cos(eq_filter_angle * np.pi / 180) - n_eqs = np.sum(angular_dists > eq_min_dist, axis=1) + n_eqs = np.count_nonzero(angular_dists > eq_min_dist, axis=1) eq_idx = n_eqs > 0 # Classify equators. From fa8c36ba9b3e7059fa6bfd4920dbeae7917ac343 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 22 Jul 2024 11:57:09 -0400 Subject: [PATCH 055/105] fix typos. --- src/aspire/abinitio/commonline_d2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index dcc6536972..b06114241b 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -403,7 +403,7 @@ def _generate_scl_angles(self, Ris, eq_idx, eq_class): # Deal with non top view equators # A non-TV equator has only one self common line. However, we clasify an # equator as an image whose projection direction is at radial distance < - # `eq_min_dist` from the great circle perpendicural to a symmetry axis, + # `eq_min_dist` from the great circle perpendicular to a symmetry axis, # and not strictly zero distance. Thus in most cases we get 2 common lines # differing by a small difference in degrees. Actually the calculation above # gives us two NEARLY antipodal lines, so we first flip one of them by @@ -454,7 +454,7 @@ def _generate_scl_angles(self, Ris, eq_idx, eq_class): + scl_angles[eq_class > 0] * ~p[:, :, None, None] ) - # Convert from angles [0,2*pi) to degrees [0, 360). + # Convert from radians [0,2*pi) to degrees [0, 360). return np.round(scl_angles * 180 / np.pi) % 360 def _generate_scl_indices(self, scl_angles, eq_class): @@ -596,10 +596,9 @@ def _compute_scl_scores(self): n_eq = len(self.non_tv_eq_idx) n_inplane = self.n_inplane_rots - # Run ML in parallel + # Prepare self-commonline indices. scl_matrix = np.concatenate((self.scl_idx_1, self.scl_idx_2)) M = len(scl_matrix) // 3 - corrs_out = np.zeros((n_img, M), dtype=self.dtype) scl_idx = scl_matrix.reshape(M, 3) # Get non-equator indices to use with corrs matrix. @@ -609,6 +608,7 @@ def _compute_scl_scores(self): scl_idx[non_eq_lin_idx].flatten(), (n_theta // 2, n_theta) ) + corrs_out = np.zeros((n_img, M), dtype=self.dtype) for i in trange(n_img): pf_full_i = self.pf_full[i] pf_i_shifted = self.pf_shifted[i] @@ -1102,7 +1102,7 @@ def _sync_colors(self, Rijs): # Compute eigenvectors of color matrix. This is just a matrix of dimensions # 3(N choose 2)x3(N choose 2) where each entry corresponds to a pair of - # matrices {Qi^T*Ir*Qj} and {Qr^T*Il*Qj} and eqauls \delta_rl. + # matrices {Qi^T*Ir*Qj} and {Qr^T*Il*Qj} and equals \delta_rl. # The 2 leading eigenvectors span a linear subspace which contains a # vector which encodes the partition. All the entries of the vector are # either 1,0 or -1, where the number encodes which the index r in Ir. From 645d7f9bb03e7ddfff9c782eeb4488742e817149 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 23 Jul 2024 10:51:25 -0400 Subject: [PATCH 056/105] Fix comment to use n_theta. --- src/aspire/abinitio/commonline_d2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index b06114241b..e3f20cd9e8 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -657,11 +657,11 @@ def _all_eq_measures(self, corrs): :return: (n_theta // 2) likelihood scores. """ - # First compute the eq measure (corrs(scl-k,scl+k) for k=1:90) + # First compute the eq measure (corrs(scl-k,scl+k) for k=1:n_theta // 4) # An equator image of a D2 molecule has the following property: If t_i is # the angle of one of the rays of the self common line then all the pairs of - # rays of the form (t_i-k,t_i+k) for k=1:90 are identical. For each t_i we - # average over correlations between the lines (t_i-k,t_i+k) for k=1:90 + # rays of the form (t_i-k,t_i+k) for k=1:n_theta // 4 are identical. For each t_i we + # average over correlations between the lines (t_i-k,t_i+k) for k=1:n_theta // 4 # to measure the likelihood that the image is an equator and the ray (line) # with angle t_i is a self common line. # (This first loop can be done once outside this function and then pass From a64c60ebfe6630239e821a3103a9f98249f57142 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 30 Jul 2024 10:44:57 -0400 Subject: [PATCH 057/105] reshape cl_angles. Remove unused attribute self.cl_angles*. --- src/aspire/abinitio/commonline_d2.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index e3f20cd9e8..69b45cb224 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -224,8 +224,8 @@ def _generate_lookup_data(self): ) # Generate commonline indices. - self.cl_idx_1, self.cl_angles1 = self._generate_commonline_indices(cl_angles1) - self.cl_idx_2, self.cl_angles2 = self._generate_commonline_indices(cl_angles2) + self.cl_idx_1 = self._generate_commonline_indices(cl_angles1) + self.cl_idx_2 = self._generate_commonline_indices(cl_angles2) self.cl_idx = np.hstack((self.cl_idx_1, self.cl_idx_2)) def _generate_commonline_angles( @@ -268,7 +268,7 @@ def _generate_commonline_angles( n_pairs = np.count_nonzero(eq2eq_Rij_table) idx = 0 - cl_angles = np.zeros((2 * n_pairs, n_theta, n_theta // 2, 4, 2)) + cl_angles = np.zeros((2, n_pairs, n_theta, n_theta // 2, 4, 2)) for i in range(n_rots_i): unique_pairs_i = np.where(eq2eq_Rij_table[i])[0] @@ -283,16 +283,16 @@ def _generate_commonline_angles( Rijs = np.transpose(g_Rj, axes=(0, 2, 1)) @ Ri[:, None] # Common line indices induced by Rijs - cl_angles[idx, :, :, k, 0] = np.arctan2( + cl_angles[0, idx, :, :, k, 0] = np.arctan2( Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] ) - cl_angles[idx, :, :, k, 1] = np.arctan2( + cl_angles[0, idx, :, :, k, 1] = np.arctan2( -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] ) - cl_angles[idx + n_pairs, :, :, k, 0] = np.arctan2( + cl_angles[1, idx, :, :, k, 0] = np.arctan2( Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] ) - cl_angles[idx + n_pairs, :, :, k, 1] = np.arctan2( + cl_angles[1, idx, :, :, k, 1] = np.arctan2( -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] ) @@ -475,7 +475,7 @@ def _generate_scl_indices(self, scl_angles, eq_class): L = self.n_theta # Convert from angles to indices. - scl_indices, _ = self._generate_commonline_indices(scl_angles) + scl_indices = self._generate_commonline_indices(scl_angles) scl_angles = np.mod(np.round(scl_angles / (2 * np.pi) * L), L).astype(int) # Create candidate common line linear indices lists for equators. @@ -1887,4 +1887,4 @@ def _generate_commonline_indices(self, cl_angles): cl_angles = np.rint(cl_angles.reshape(og_shape)).astype(int) # Return as integer indices. - return cl_idx, cl_angles + return cl_idx From b436ddd8c5d009496bba65cf009314177acd43e6 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 30 Jul 2024 11:47:23 -0400 Subject: [PATCH 058/105] replace loop with broadcast. --- src/aspire/abinitio/commonline_d2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 69b45cb224..6c114efc74 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -357,8 +357,7 @@ def _generate_scl_lookup_data(self): ) * self.n_inplane_rots ) - for i in range(1, self.n_inplane_rots): - non_eq_idx[:, i] = non_eq_idx[:, 0] + i + non_eq_idx[:, 1:] = non_eq_idx[:, [0]] + np.arange(1, self.n_inplane_rots) self.non_eq_idx = non_eq_idx From e2678bcb1412c8f9d8363b0b1c80bc5b3f5e69fc Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 30 Jul 2024 13:26:05 -0400 Subject: [PATCH 059/105] Clean up _all_eq_measures using broadcasting. --- src/aspire/abinitio/commonline_d2.py | 53 +++++++++++++--------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 6c114efc74..5e29f6a522 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -666,27 +666,27 @@ def _all_eq_measures(self, corrs): # (This first loop can be done once outside this function and then pass # idx as an argument). L = self.n_theta - idx = np.zeros((L // 2, L // 4, 2)) - idx_1 = np.mod( - np.vstack((-np.arange(1, L // 4 + 1), np.arange(1, L // 4 + 1))), L - ) - idx[0, :, :] = idx_1.T - for k in range(1, L // 2): - idx[k, :, :] = np.mod(idx_1.T + k, L) - idx = np.mod(idx, L) + L_half = L // 2 + + # Generate indices using broadcasting. + t_i = np.arange(L_half)[:, None, None] + k_vals = np.arange(1, L // 4 + 1)[None, :, None] + neg_pos_k = np.array([-1, 1])[None, None, :] + + # Calculate indices, shape: (L//2, L//4, 2). + idx = np.mod(t_i + k_vals * neg_pos_k, L) # Convert to Fourier ray indices. idx_1 = idx[:, :, 0].flatten() idx_2 = idx[:, :, 1].flatten() - # Make all Ri coordinates < 180 and compute linear indices for corrrelations - is_geq_than_pi = idx_1 >= L // 2 - idx_1[is_geq_than_pi] = idx_1[is_geq_than_pi] - (L // 2) - idx_2[is_geq_than_pi] = (idx_2[is_geq_than_pi] + (L // 2)) % L + # Adjust idx_1 to be within [0, 180) and adjust idx_2 accordingly. + is_geq_than_pi = idx_1 >= L_half + idx_1[is_geq_than_pi] -= L_half + idx_2[is_geq_than_pi] = (idx_2[is_geq_than_pi] + L_half) % L - # Compute correlations. - eq_corrs = corrs[idx_1.astype(int), idx_2.astype(int)] - eq_corrs = eq_corrs.reshape(L // 2, L // 4) + # Compute correlations + eq_corrs = corrs[idx_1, idx_2].reshape(L_half, L // 4) corrs_mean = np.mean(eq_corrs, axis=1) # Now compute correlations for normals to scls. @@ -698,22 +698,19 @@ def _all_eq_measures(self, corrs): # equator and t_i+0.5*pi is the normal to its self common line. r = 2 # Search radius within 2 adjacent rays of normal ray. - normal_2_scl_idx = np.zeros((L // 2, 2 * r + 1)) - normal_2_scl_idx_1 = np.mod(L // 2 - np.arange(L // 4 - r, L // 4 + r + 1), L) - normal_2_scl_idx[0, :] = normal_2_scl_idx_1 - for k in range(1, L // 2): - normal_2_scl_idx[k, :] = np.mod(normal_2_scl_idx_1 + k, L) + # Generate indices for normal to scl index. + normal_2_scl_idx_0 = ( + L_half - np.arange(L_half // 2 - r, L_half // 2 + r + 1) + ) % L + normal_2_scl_idx = (normal_2_scl_idx_0 + np.arange(L_half).reshape(-1, 1)) % L - # Make all Ri coordinates <=180 and compute linear indices for corrrelations - is_geq_than_pi = normal_2_scl_idx >= L // 2 - normal_2_scl_idx[is_geq_than_pi] = normal_2_scl_idx[is_geq_than_pi] - (L // 2) + # Adjust indices to be within [0, 180) range. + normal_2_scl_idx = np.where( + normal_2_scl_idx >= L_half, normal_2_scl_idx - L_half, normal_2_scl_idx + ) # Compute correlations for normals. - normal_2_scl_idx = normal_2_scl_idx.flatten() - normal_corrs = corrs[ - normal_2_scl_idx.astype(int), normal_2_scl_idx.astype(int) + (L // 2) - ] - normal_corrs = normal_corrs.reshape(L // 2, 2 * r + 1) + normal_corrs = corrs[normal_2_scl_idx, normal_2_scl_idx + L_half] normal_corrs_max = np.max(normal_corrs, axis=1) return corrs_mean * normal_corrs_max From aba51822d6bb6c78ee191249219b2399a46fbec8 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 30 Jul 2024 14:49:47 -0400 Subject: [PATCH 060/105] Vectorize _compute_cl_scores. --- src/aspire/abinitio/commonline_d2.py | 72 +++++++++++++++------------- 1 file changed, 39 insertions(+), 33 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 5e29f6a522..680f22e79a 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -726,9 +726,9 @@ def _compute_cl_scores(self): """ logger.info("Computing commonline correlation scores.") L = self.n_theta + n_pairs = self.n_img * (self.n_img - 1) // 2 - # Map the self common line scores of each 2 candidate rotations R_i,R_j to - # the respective relative rotation candidate R_i^TR_j. + # Map the self common line scores of each 2 candidate rotations R_i, R_j n_lookup_1 = len(self.scl_idx_1) // 3 oct1_ij_map = np.vstack((self.oct1_ij_map, self.oct1_ij_map[:, [1, 0]])) oct2_ij_map = self.oct2_ij_map @@ -736,50 +736,56 @@ def _compute_cl_scores(self): oct2_ij_map = np.vstack((oct2_ij_map, oct2_ij_map[:, [1, 0]])) ij_map = np.vstack((oct1_ij_map, oct2_ij_map)) - # Allocate output variables. - n_pairs = self.n_img * (self.n_img - 1) // 2 + # Gather commonline indices and unravel to index into correlations. + cl_idx = np.unravel_index(self.cl_idx, (L // 2, L)) + + # Allocate output variables corrs_idx = np.zeros(n_pairs, dtype=np.int64) corrs_out = np.zeros(n_pairs, dtype=self.dtype) - ij_idx = 0 - # Search for common lines between pairs of projections. + ij_idx = 0 pbar = tqdm( desc="Searching for commonlines between pairs of images", total=n_pairs ) - for i in range(self.n_img): + + # Vectorize over pairs of images + for i in range(self.n_img - 1): pf_i = self.pf_shifted[i] scores_i = self.scls_scores[i] - for j in range(i + 1, self.n_img): - pf_j = self.pf_full[j] - - # Compute maximum correlation over all shifts. - corrs = np.real(pf_i @ np.conj(pf_j).T) - corrs = np.reshape(corrs, (self.n_shifts, L // 2, L)) - corrs = np.max(corrs, axis=0) - - # Take the product over symmetrically induced candidates. Eq. 4.5 in paper. - cl_idx = np.unravel_index(self.cl_idx, (L // 2, L)) - - prod_corrs = corrs[cl_idx] - prod_corrs = prod_corrs.reshape(len(prod_corrs) // 4, 4) - prod_corrs = np.prod(prod_corrs, axis=1) - - # Incorporate scores of individual rotations from self-commonlines. - scores_j = self.scls_scores[j] - scores_ij = scores_i[ij_map[:, 0]] * scores_j[ij_map[:, 1]] + # Gather all pf_j in one array for vectorized computation + pf_js = self.pf_full[i + 1 : self.n_img] + n_pf_js = pf_js.shape[0] + + # Compute maximum correlation over all shifts for all pf_j + corrs = np.real(pf_i @ np.conj(pf_js.transpose(0, 2, 1))) + corrs = corrs.reshape(n_pf_js, self.n_shifts, L // 2, L) + corrs = np.max(corrs, axis=1) # Max over shifts + + # Take the product over symmetrically induced candidates. Eq. 4.5 in paper. + # Vectorize extraction and processing of correlations + prod_corrs = corrs[:, cl_idx[0], cl_idx[1]] + prod_corrs = prod_corrs.reshape(n_pf_js, len(prod_corrs[0]) // 4, 4) + prod_corrs = np.prod(prod_corrs, axis=2) + + # Incorporate scores of individual rotations from self-commonlines + scores_js = self.scls_scores[i + 1 : self.n_img] + scores_ij = scores_i[ij_map[:, 0]] * scores_js[:, ij_map[:, 1]] + + # Find maximum correlations and update results + prod_corrs = prod_corrs * scores_ij + max_indices = np.argmax(prod_corrs, axis=1) + corrs_idx[ij_idx : ij_idx + len(max_indices)] = max_indices + corrs_out[ij_idx : ij_idx + len(max_indices)] = prod_corrs[ + np.arange(len(max_indices)), max_indices + ] - # Find maximum correlations. - prod_corrs = prod_corrs * scores_ij - max_idx = np.argmax(prod_corrs) - corrs_idx[ij_idx] = max_idx - corrs_out[ij_idx] = prod_corrs[max_idx] - ij_idx += 1 + ij_idx += len(max_indices) + pbar.update(len(max_indices)) - pbar.update() pbar.close() - # Get estimated relative viewing directions. + # Get estimated relative viewing directions self.corrs_idx = corrs_idx self.Rijs_est = self._get_Rijs_from_lin_idx(corrs_idx) From 8a3db369e57c3b1c33e38066016e79bca1a5b265 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 5 Aug 2024 09:15:46 -0400 Subject: [PATCH 061/105] Broadcast when computing Rijs. --- src/aspire/abinitio/commonline_d2.py | 39 +++++++++++++++------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 680f22e79a..6eacd50d05 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -277,24 +277,27 @@ def _generate_commonline_angles( Ri = Ris[i] for j in unique_pairs_i: Rj = Rjs[j, : n_theta // 2] - for k, g in enumerate(self.gs): - # Compute relative rotations candidates Rij = Ri.T @ gs @ Rj - g_Rj = g @ Rj - Rijs = np.transpose(g_Rj, axes=(0, 2, 1)) @ Ri[:, None] - - # Common line indices induced by Rijs - cl_angles[0, idx, :, :, k, 0] = np.arctan2( - Rijs[:, :, 2, 0], -Rijs[:, :, 2, 1] - ) - cl_angles[0, idx, :, :, k, 1] = np.arctan2( - -Rijs[:, :, 0, 2], Rijs[:, :, 1, 2] - ) - cl_angles[1, idx, :, :, k, 0] = np.arctan2( - Rijs[:, :, 0, 2], -Rijs[:, :, 1, 2] - ) - cl_angles[1, idx, :, :, k, 1] = np.arctan2( - -Rijs[:, :, 2, 0], Rijs[:, :, 2, 1] - ) + + # Compute relative rotations candidates Rij = Ri.T @ gs @ Rj + Rijs = ( + np.transpose(Ri, axes=(0, 2, 1))[:, None, None] + @ self.gs + @ Rj[:, None] + ) + + # Common line indices induced by Rijs + cl_angles[0, idx, :, :, :, 0] = np.arctan2( + -Rijs[..., 0, 2], Rijs[..., 1, 2] + ) + cl_angles[0, idx, :, :, :, 1] = np.arctan2( + Rijs[..., 2, 0], -Rijs[..., 2, 1] + ) + cl_angles[1, idx, :, :, :, 0] = np.arctan2( + -Rijs[..., 2, 0], Rijs[..., 2, 1] + ) + cl_angles[1, idx, :, :, :, 1] = np.arctan2( + Rijs[..., 0, 2], -Rijs[..., 1, 2] + ) idx += 1 From 5c498fb6e920a055ba8e99fd11809c5c81b08d6e Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 5 Aug 2024 10:00:05 -0400 Subject: [PATCH 062/105] loop -> broadcast --- src/aspire/abinitio/commonline_d2.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 6eacd50d05..e8374ff0e0 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -852,8 +852,6 @@ def _get_Rijs_from_oct(self, lin_idx, octant=1): self.inplane_rotated_grid2, (np.prod(s2[0:2]), 3, 3) ) - Rijs_est = np.zeros((n_pairs, 4, 3, 3), dtype=self.dtype) - # Convert linear indices of unique table to linear indices of index pairs table. idx_vec = np.arange(np.prod(unique_pairs.shape)) unique_lin_idx = idx_vec[unique_pairs.flatten()] @@ -865,9 +863,7 @@ def _get_Rijs_from_oct(self, lin_idx, octant=1): Rjs_lin_idx = np.ravel_multi_index((est_idx[1], inplane_j), s2[:2]) Ris_t = np.transpose(inplane_rotated_grid[Ris_lin_idx], (0, 2, 1)) Rjs = inplane_rotated_grid2[Rjs_lin_idx] - - for k, g in enumerate(self.gs): - Rijs_est[:, k] = Ris_t @ g @ Rjs + Rijs_est = Ris_t[:, None] @ self.gs @ Rjs[:, None] Rijs_est[transpose_idx] = np.transpose(Rijs_est[transpose_idx], (0, 1, 3, 2)) From 8f9b0236bd6f88d8d9ba66b18ff9170a7cfb275d Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 5 Aug 2024 10:15:17 -0400 Subject: [PATCH 063/105] rename func --- src/aspire/abinitio/commonline_d2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index e8374ff0e0..d502be9274 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1000,7 +1000,7 @@ def _J_sync_power_method(self, J_list): ) while itr < max_iters and residual > epsilon: itr += 1 - vec_new = self._signs_times_v2(J_list, vec) + vec_new = self._signs_times_v(J_list, vec) vec_new = vec_new / norm(vec_new) residual = norm(vec_new - vec) vec = vec_new @@ -1014,7 +1014,7 @@ def _J_sync_power_method(self, J_list): return J_sync - def _signs_times_v2(self, J_list, vec): + def _signs_times_v(self, J_list, vec): """ Multiplication of the J-synchronization matrix by a candidate eigenvector. From 8901213e8214590727495783101ea2f67b1f1260 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 5 Aug 2024 11:10:16 -0400 Subject: [PATCH 064/105] _compare_rots docstring and broadcast. --- src/aspire/abinitio/commonline_d2.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index d502be9274..6ecb2ed008 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -949,21 +949,25 @@ def _compare_rots(self, Rij, Rjk_t, Rik): Compute norms for the 4 J-configurations and return indices corresponding to best configuration for the provided triplet of relative rotations. + + :param Rij: Relative rotation between i'th and j'th candidate rotations + of shape (4, 3, 3). + :param Rjk_t: Transpose of relative rotation between j'th and k'th candidate + rotations of shape (4, 3, 3). + :param Rik: Relative rotation between i'th and k'th candidate rotations + of shape (4, 3, 3). + :return: Score for this J-configuration of the given rotation triplet. """ # We compute the four sets of 4^3 norms |Rik @ Rjk.T - Rij| # See equation (6.11) in publication. prod_arr = Rik[:, None] @ Rjk_t[None] + diff_arr = prod_arr[:, :, None] - Rij + diff_arr = diff_arr.reshape((64, 9)) + norm_arr = np.sum(diff_arr**2, axis=1) - arr = np.zeros((8, 8, 3, 3), dtype=self.dtype) - arr[0:4, 0:4] = prod_arr - Rij[0] - arr[0:4, 4:8] = prod_arr - Rij[1] - arr[4:8, 0:4] = prod_arr - Rij[2] - arr[4:8, 4:8] = prod_arr - Rij[3] - - arr = arr.reshape((64, 9)) - arr = np.sum(arr**2, axis=1) - - m = np.sort(arr.flatten()) + # For perfect estimates, 16 of the 64 norms will equal zero. + # We sum over the smallest 16 values to get a vote for this J-configuration. + m = np.sort(norm_arr) vote = np.sum(m[:16]) return vote From 2a4c6aa9a36a8b0e4bfa882b7e47114fa201aa89 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 6 Aug 2024 09:31:08 -0400 Subject: [PATCH 065/105] broadcasting --- src/aspire/abinitio/commonline_d2.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 6ecb2ed008..d3771c9168 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1171,15 +1171,11 @@ def _match_colors(self, Rijs_rows): self_prods = self_prods.reshape(3, 9) prod_arr1 = prod_arr.copy() - prod_arr1[:, :, 0, :] = prod_arr1[:, :, 0, :] - self_prods[0] - prod_arr1[:, :, 1, :] = prod_arr1[:, :, 1, :] - self_prods[1] - prod_arr1[:, :, 2, :] = prod_arr1[:, :, 2, :] - self_prods[2] + prod_arr1 -= self_prods norms1 = np.sum(prod_arr1**2, axis=3) prod_arr2 = prod_arr.copy() - prod_arr2[:, :, 0, :] = prod_arr2[:, :, 0, :] + self_prods[0] - prod_arr2[:, :, 1, :] = prod_arr2[:, :, 1, :] + self_prods[1] - prod_arr2[:, :, 2, :] = prod_arr2[:, :, 2, :] + self_prods[2] + prod_arr2 += self_prods norms2 = np.sum(prod_arr2**2, axis=3) # Compare to v_{jj}(r)=v_{jk}v_{kj}. @@ -1187,15 +1183,11 @@ def _match_colors(self, Rijs_rows): self_prods = self_prods.reshape(3, 9) prod_arr1 = prod_arr.copy() - prod_arr1[0] = prod_arr1[0] - self_prods[0] - prod_arr1[1] = prod_arr1[1] - self_prods[1] - prod_arr1[2] = prod_arr1[2] - self_prods[2] + prod_arr1 -= self_prods[:, None, None] norms1 = norms1 + np.sum(prod_arr1**2, axis=3) prod_arr2 = prod_arr.copy() - prod_arr2[0] = prod_arr2[0] + self_prods[0] - prod_arr2[1] = prod_arr2[1] + self_prods[1] - prod_arr2[2] = prod_arr2[2] + self_prods[2] + prod_arr2 += self_prods[:, None, None] norms2 = norms2 + np.sum(prod_arr2**2, axis=3) # For r=1:3 compute 3*3 products v_{ij}(r)v_{jk}v_{ki} and compare to @@ -1211,15 +1203,11 @@ def _match_colors(self, Rijs_rows): self_prods = self_prods.reshape(3, 9) prod_arr1 = prod_arr.copy() - prod_arr1[:, 0] = prod_arr1[:, 0] - self_prods[0] - prod_arr1[:, 1] = prod_arr1[:, 1] - self_prods[1] - prod_arr1[:, 2] = prod_arr1[:, 2] - self_prods[2] + prod_arr1 -= self_prods[None, :, None] norms1 = norms1 + np.sum(prod_arr1**2, axis=3) prod_arr2 = prod_arr.copy() - prod_arr2[:, 0] = prod_arr2[:, 0] + self_prods[0] - prod_arr2[:, 1] = prod_arr2[:, 1] + self_prods[1] - prod_arr2[:, 2] = prod_arr2[:, 2] + self_prods[2] + prod_arr2 += self_prods[None, :, None] norms2 = norms2 + np.sum(prod_arr2**2, axis=3) norms = np.minimum(norms1, norms2) From 1e8e05f6caed36e6b05e0945a3dfb3cc2ff90c07 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 6 Aug 2024 14:00:48 -0400 Subject: [PATCH 066/105] Reshape arrays and use cleaner indexing. --- src/aspire/abinitio/commonline_d2.py | 41 ++++++++-------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index d3771c9168..9871be3325 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1417,19 +1417,13 @@ def _sync_signs(self, rr, c_vec): # o(N^2) which doesn't pose a constraint for inputs on the scale of 10^3-10^4. c_mat_5d = np.zeros((self.n_img, self.n_img, 3, 3, 3), dtype=self.dtype) c_mat_4d = np.zeros((self.n_pairs, 3, 3, 3), dtype=self.dtype) + c_vec = c_vec.reshape(self.n_pairs, 3) for i in range(self.n_img - 1): for j in range(i + 1, self.n_img): ij = self.pairs_to_linear[i, j] - c_mat_5d[i, j, c_vec[3 * ij]] = rr[ij, 0] - c_mat_5d[i, j, c_vec[3 * ij + 1]] = rr[ij, 1] - c_mat_5d[i, j, c_vec[3 * ij + 2]] = rr[ij, 2] - c_mat_5d[j, i, c_vec[3 * ij]] = rr[ij, 0].T - c_mat_5d[j, i, c_vec[3 * ij + 1]] = rr[ij, 1].T - c_mat_5d[j, i, c_vec[3 * ij + 2]] = rr[ij, 2].T - - c_mat_4d[ij, c_vec[3 * ij]] = rr[ij, 0] - c_mat_4d[ij, c_vec[3 * ij + 1]] = rr[ij, 1] - c_mat_4d[ij, c_vec[3 * ij + 2]] = rr[ij, 2] + c_mat_5d[i, j, c_vec[ij]] = rr[ij] + c_mat_5d[j, i, c_vec[ij]] = rr[ij].transpose(0, 2, 1) + c_mat_4d[ij, c_vec[ij]] = rr[ij] # Compute estimates for the tuples {0.5*(Ri^TRi+Ri^TgkRi), k=1:3} for # i=1:N. For 1<=i,j<=N and c=1,2,3 write Qij^c=0.5*(Ri^TRj+Ri^TgmRj). @@ -1452,26 +1446,18 @@ def _sync_signs(self, rr, c_vec): # In C_2 one such matrix is constructed for the 3rd rows # and is rank 1 by construction. In practice, thus far, for each c and # (i,j) we either have Qij^c or -Qij^c independently. - c_mat = np.zeros((3, 3 * self.n_img, 3 * self.n_img), dtype=self.dtype) + c_mat = np.zeros((3, self.n_img, 3, self.n_img, 3), dtype=self.dtype) rot = np.zeros((self.n_img, 3, 3), dtype=self.dtype) for i in range(self.n_img - 1): for j in range(i + 1, self.n_img): ij = self.pairs_to_linear[i, j] - c_mat[c_vec[3 * ij], 3 * i : 3 * i + 3, 3 * j : 3 * j + 3] = rr[ij, 0] - c_mat[c_vec[3 * ij + 1], 3 * i : 3 * i + 3, 3 * j : 3 * j + 3] = rr[ - ij, 1 - ] - c_mat[c_vec[3 * ij + 2], 3 * i : 3 * i + 3, 3 * j : 3 * j + 3] = rr[ - ij, 2 - ] - - c_mat[0] = c_mat[0] + c_mat[0].T - c_mat[1] = c_mat[1] + c_mat[1].T - c_mat[2] = c_mat[2] + c_mat[2].T + c_mat[c_vec[ij], i, :, j, :] = rr[ij] + + c_mat = c_mat + c_mat.transpose(0, 3, 4, 1, 2) for c in range(3): for i in range(self.n_img): - c_mat[c, 3 * i : 3 * i + 3, 3 * i : 3 * i + 3] = c_mat_5d[i, i, c] + c_mat[c, i, :, i, :] = c_mat_5d[i, i, c] # To decompose cMat as a rank 1 matrix we need to adjust the signs of the # Qij^c so that sign(Qij^c*Qjk^c) = sign(Qik^c) for all c=1,2,3 and (i,j). @@ -1628,15 +1614,12 @@ def _sync_signs(self, rr, c_vec): idx = 0 for i in range(self.n_img - 1): for j in range(i + 1, self.n_img): - c_mat[c, 3 * j : 3 * j + 3, 3 * i : 3 * i + 3] = ( - signs[c, idx] * c_mat[c, 3 * j : 3 * j + 3, 3 * i : 3 * i + 3] - ) - c_mat[c, 3 * i : 3 * i + 3, 3 * j : 3 * j + 3] = ( - signs[c, idx] * c_mat[c, 3 * i : 3 * i + 3, 3 * j : 3 * j + 3] - ) + c_mat[c, j, :, i, :] *= signs[c, idx] + c_mat[c, i, :, j, :] *= signs[c, idx] idx += 1 # cMat(:,:,c) are now rank 1. Decompose using SVD and take leading eigenvector. + c_mat = c_mat.reshape(3, 3 * self.n_img, 3 * self.n_img) U1, S1, _ = la.svds(c_mat[0], k=3, which="LM") U2, S2, _ = la.svds(c_mat[1], k=3, which="LM") U3, S3, _ = la.svds(c_mat[2], k=3, which="LM") From b7b53a83234404755cc8e68d736ea55aa4fa7f97 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 6 Aug 2024 15:18:55 -0400 Subject: [PATCH 067/105] more indexing cleanup --- src/aspire/abinitio/commonline_d2.py | 42 +++++++++++----------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 9871be3325..f6f405f209 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1447,7 +1447,6 @@ def _sync_signs(self, rr, c_vec): # and is rank 1 by construction. In practice, thus far, for each c and # (i,j) we either have Qij^c or -Qij^c independently. c_mat = np.zeros((3, self.n_img, 3, self.n_img, 3), dtype=self.dtype) - rot = np.zeros((self.n_img, 3, 3), dtype=self.dtype) for i in range(self.n_img - 1): for j in range(i + 1, self.n_img): ij = self.pairs_to_linear[i, j] @@ -1620,34 +1619,27 @@ def _sync_signs(self, rr, c_vec): # cMat(:,:,c) are now rank 1. Decompose using SVD and take leading eigenvector. c_mat = c_mat.reshape(3, 3 * self.n_img, 3 * self.n_img) - U1, S1, _ = la.svds(c_mat[0], k=3, which="LM") - U2, S2, _ = la.svds(c_mat[1], k=3, which="LM") - U3, S3, _ = la.svds(c_mat[2], k=3, which="LM") - svals2 = np.zeros((3, 3), dtype=self.dtype) - svals2[0] = S1[::-1] - svals2[1] = S2[::-1] - svals2[2] = S3[::-1] - - # Stable eigenvectors. - U1 = np.sign(U1[0]) * U1 - U2 = np.sign(U2[0]) * U2 - U3 = np.sign(U3[0]) * U3 + U1, _, _ = la.svds(c_mat[0], k=3, which="LM") + U2, _, _ = la.svds(c_mat[1], k=3, which="LM") + U3, _, _ = la.svds(c_mat[2], k=3, which="LM") + + # Stabilize and take leading eigenvector. + U1 = np.sign(U1[0, -1]) * U1[:, -1] + U2 = np.sign(U2[0, -1]) * U2[:, -1] + U3 = np.sign(U3[0, -1]) * U3[:, -1] # The c'th row of the rotation Rj is Uc(3*j-2:3*j,1)/norm(Uc(3*j-2:3*j,1)), # (Rows must be normalized to length 1). logger.info("Assembeling rows to rotations matrices.") - for i in range(self.n_img): - rot[i, 0] = U1[3 * i : 3 * i + 3, -1] / np.linalg.norm( - U1[3 * i : 3 * i + 3, -1] - ) - rot[i, 1] = U2[3 * i : 3 * i + 3, -1] / np.linalg.norm( - U2[3 * i : 3 * i + 3, -1] - ) - rot[i, 2] = U3[3 * i : 3 * i + 3, -1] / np.linalg.norm( - U3[3 * i : 3 * i + 3, -1] - ) - if np.linalg.det(rot[i]) < 0: - rot[i, 2] = -rot[i, 2] + rot = np.zeros((self.n_img, 3, 3), dtype=self.dtype) + rot[:, 0] = U1.reshape(self.n_img, 3) + rot[:, 1] = U2.reshape(self.n_img, 3) + rot[:, 2] = U3.reshape(self.n_img, 3) + rot /= np.linalg.norm(rot, axis=-1)[:, :, None] + + # Ensure we have rotations. + not_a_rot = np.argwhere(np.linalg.det(rot) < 0) + rot[not_a_rot, 2] *= -1 return rot From 8000740eeb4e365d4b44623d3fb3edb3b719eaf1 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 6 Aug 2024 15:22:40 -0400 Subject: [PATCH 068/105] remove unused variable --- src/aspire/abinitio/commonline_d2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index f6f405f209..8691f358f0 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -831,7 +831,6 @@ def _get_Rijs_from_oct(self, lin_idx, octant=1): n_rots2 = n_rots else: n_rots2 = len(self.sphere_grid2) - n_pairs = len(lin_idx) # Map linear indices of chosen pairs of rotation candidates from ML to regular indices. p_idx, inplane_i, inplane_j = np.unravel_index( From 4bbc85d4cc7eeed2ffb5b61c1405802f184af710 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 6 Aug 2024 15:51:56 -0400 Subject: [PATCH 069/105] circ_seq value check --- src/aspire/abinitio/commonline_d2.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 8691f358f0..d1b070763e 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1679,13 +1679,18 @@ def _mult_smat_by_vec(self, v, sign_mat, pairs_map): @staticmethod def _circ_seq(n1, n2, L): """ - Make a circular sequence of integers between n1 and n2 modulo L. + For integers 0 <= n1, n2 < L, make a circular sequence of integers between + n1 and n2 modulo L. :param n1: First integer in sequence. :param n2: Last integer in sequence. :param L: Modulus of values in sequence. :return: Circular sequence modulo L. """ + if min(n1, n2) < 0 or max(n1, n2) >= L: + raise ValueError( + f"n1 and n2 must both be in [0, {L}). Found n1={n1}, n2={n2}." + ) if n2 < n1: n2 += L if n1 == n2: From 3fa9c1b1465d14e3698188756d83697be8bb7260 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 8 Aug 2024 15:47:06 -0400 Subject: [PATCH 070/105] compute_scl_scores: Replace loop with vectorized operations. --- src/aspire/abinitio/commonline_d2.py | 77 +++++++++++++++------------- 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index d1b070763e..32815c1a3b 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -610,44 +610,47 @@ def _compute_scl_scores(self): scl_idx[non_eq_lin_idx].flatten(), (n_theta // 2, n_theta) ) + # Compute max correlation over all shifts. + corrs = np.real( + self.pf_shifted @ np.transpose(np.conj(self.pf_full), (0, 2, 1)) + ) + corrs = np.reshape(corrs, (self.n_img, self.n_shifts, n_theta // 2, n_theta)) + corrs = np.max(corrs, axis=1) + + # Map correlations to probabilities (in the spirit of Maximum Likelihood). + corrs = 0.5 * (corrs + 1) + + # Compute equator measures. + eq_measures = np.zeros((self.n_img, n_theta // 2), dtype=self.dtype) + for i in range(self.n_img): + eq_measures[i] = self._all_eq_measures(corrs[i]) + + # Handle the cases: Non-equator, Non-top-view equator images. + # 1. Non-equators: just take product of probabilities. corrs_out = np.zeros((n_img, M), dtype=self.dtype) - for i in trange(n_img): - pf_full_i = self.pf_full[i] - pf_i_shifted = self.pf_shifted[i] - - # Compute max correlation over all shifts. - corrs = np.real(pf_i_shifted @ np.conj(pf_full_i).T) - corrs = np.reshape(corrs, (self.n_shifts, n_theta // 2, n_theta)) - corrs = np.max(corrs, axis=0) - - # Map correlations to probabilities (in the spirit of Maximum Likelihood). - corrs = 0.5 * (corrs + 1) - - # Compute equator measures. - eq_measures = self._all_eq_measures(corrs) - - # Handle the cases: Non-equator, Non-top-view equator images. - # 1. Non-equators: just take product of probabilities. - prod_corrs = np.prod(corrs[non_eq_idx].reshape(n_non_eq, 3), axis=1) - corrs_out[i, non_eq_lin_idx] = prod_corrs - - # 2. Non-topview equators: adjust scores by eq_measures - for eq_idx in range(n_eq): - for j in range(n_inplane): - # Take the correlations for the self common line candidate of the - # "equator rotation" `eq_idx` with respect to image i, and - # multiply by all scores from the function eq_measures (see - # documentation inside the function ). Then take maximum over - # all the scores. - scl_idx_list = np.unravel_index( - self.scl_idx_lists[0, eq_idx, j], (n_theta // 2, n_theta) - ) - true_scls_corrs = corrs[scl_idx_list] - scls_cand_idx = self.scl_idx_lists[1, eq_idx, j] - eq_measures_j = eq_measures[scls_cand_idx] - measures_agg = np.outer(true_scls_corrs, eq_measures_j) - k = self.non_tv_eq_idx[eq_idx] - corrs_out[i, k * n_inplane + j] = np.max(measures_agg) + prod_corrs = np.prod( + corrs[:, non_eq_idx[0], non_eq_idx[1]].reshape(self.n_img, n_non_eq, 3), + axis=2, + ) + corrs_out[:, non_eq_lin_idx] = prod_corrs + + # 2. Non-topview equators: adjust scores by eq_measures + for eq_idx in range(n_eq): + for j in range(n_inplane): + # Take the correlations for the self common line candidate of the + # "equator rotation" `eq_idx` with respect to image i, and + # multiply by all scores from the function eq_measures (see + # documentation inside the function ). Then take maximum over + # all the scores. + scl_idx_list = np.unravel_index( + self.scl_idx_lists[0, eq_idx, j], (n_theta // 2, n_theta) + ) + true_scls_corrs = corrs[:, scl_idx_list[0], scl_idx_list[1]] + scls_cand_idx = self.scl_idx_lists[1, eq_idx, j] + eq_measures_j = eq_measures[:, scls_cand_idx] + measures_agg = true_scls_corrs[:, :, None] * eq_measures_j[:, None, :] + k = self.non_tv_eq_idx[eq_idx] + corrs_out[:, k * n_inplane + j] = np.max(measures_agg, axis=(-2, -1)) self.scls_scores = corrs_out From 4baa05631af5c8b7e0afeb8605f4d1209cae349d Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 13 Aug 2024 11:07:00 -0400 Subject: [PATCH 071/105] Add doc io --- src/aspire/abinitio/commonline_d2.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 32815c1a3b..d851767bda 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1409,6 +1409,12 @@ def _sync_signs(self, rr, c_vec): This function executes the final stage of the algorithm, Signs synchroniztion. At the end all rows of the rotations Ri are exctracted and the matrices Ri are assembled. + + :param rr: Array of color synchronized rotations' rows outer products of + shape (n_pairs, 3, 3, 3), where each rr[ij] corresponds to a 3-tuple + of m'th row outer product matrices, some of which having a spurious -1. + :param c_vec: A color mapping vector of length (n_pairs * 3) which permutes + the 3-tuples of `rr` to be globally row-consistent. """ logger.info("Performing signs synchronization.") # Partition the union of tuples {0.5*(Ri^TRj+Ri^TgkRj), k=1:3} according From 234ecee68c119d8a46fc08cc566cbb1b02745fbd Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 13 Aug 2024 11:32:41 -0400 Subject: [PATCH 072/105] Documentation for mark_equators. --- src/aspire/abinitio/commonline_d2.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index d851767bda..af0755f093 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1415,6 +1415,7 @@ def _sync_signs(self, rr, c_vec): of m'th row outer product matrices, some of which having a spurious -1. :param c_vec: A color mapping vector of length (n_pairs * 3) which permutes the 3-tuples of `rr` to be globally row-consistent. + :return: n_img x 3 x 3 array of rotation matrices. """ logger.info("Performing signs synchronization.") # Partition the union of tuples {0.5*(Ri^TRj+Ri^TgkRj), k=1:3} according @@ -1743,6 +1744,12 @@ def _saff_kuijlaars(N): @staticmethod def _mark_equators(sphere_grid, eq_filter_angle): """ + This method categorizes a set of 3D unit vectors into equator and non-equator + vectors determined by the parameter `eq_filter_angle`, returned as `eq_idx`. + It further categorizes the vectors into the classes non_equator, z-equator, + y-equator, x-equator, z-top_view, y-top_view, and x-top_view, which are labeled + respectively with the values 0 - 6 and returned as `eq_class`. + :param sphere_grid: Nx3 array of vertices in cartesian coordinates. :param eq_filter_angle: Angular distance from equator to be marked as an equator point. From 8a183602572f028073c69c847ac5285bf6fb8804 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 14 Aug 2024 08:57:35 -0400 Subject: [PATCH 073/105] Clean up mark_equators --- src/aspire/abinitio/commonline_d2.py | 30 +++++++++++----------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index af0755f093..26258de789 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1763,23 +1763,12 @@ def _mark_equators(sphere_grid, eq_filter_angle): n_rots = len(sphere_grid) angular_dists = np.zeros((n_rots, 3), dtype=sphere_grid.dtype) - # Distance from z-axis equator. - proj_xy = sphere_grid.copy() - proj_xy[:, 2] = 0 - proj_xy /= np.linalg.norm(proj_xy, axis=1)[:, None] - angular_dists[:, 0] = np.sum(sphere_grid * proj_xy, axis=-1) - - # Distance from y-axis equator. - proj_xz = sphere_grid.copy() - proj_xz[:, 1] = 0 - proj_xz /= np.linalg.norm(proj_xz, axis=1)[:, None] - angular_dists[:, 1] = np.sum(sphere_grid * proj_xz, axis=-1) - - # Distance from x-axis equator. - proj_yz = sphere_grid.copy() - proj_yz[:, 0] = 0 - proj_yz /= np.linalg.norm(proj_yz, axis=1)[:, None] - angular_dists[:, 2] = np.sum(sphere_grid * proj_yz, axis=-1) + # For each grid point get the distance from the z, y, and x-axis equators. + for i in range(3): + proj_along_axis = sphere_grid.copy() + proj_along_axis[:, 2 - i] = 0 + proj_along_axis /= np.linalg.norm(proj_along_axis, axis=1)[:, None] + angular_dists[:, i] = np.sum(sphere_grid * proj_along_axis, axis=-1) # Mark all views close to an equator. eq_min_dist = np.cos(eq_filter_angle * np.pi / 180) @@ -1795,9 +1784,14 @@ def _mark_equators(sphere_grid, eq_filter_angle): # 5 -> y top view # 6 -> x top view eq_class = np.zeros(n_rots) + + # Grid points which are equator points with respect to 2 equators are considered top views. + # For example, a grid point that is close to both the x and y equator is a z top view. top_view_idx = n_eqs > 1 - top_view_class = np.argmin(angular_dists[top_view_idx] > eq_min_dist) + top_view_class = np.argmin(angular_dists[top_view_idx] > eq_min_dist, axis=1) eq_class[top_view_idx] = top_view_class + 4 + + # Assign grid points which are equator points with respect to only 1 equator. eq_view_idx = n_eqs == 1 eq_view_class = np.argmax(angular_dists[eq_view_idx] > eq_min_dist, axis=1) eq_class[eq_view_idx] = eq_view_class + 1 From 55d6897ba45e3c766c33b7bf2944b86b821537e7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 14 Aug 2024 10:32:29 -0400 Subject: [PATCH 074/105] Add documentation to _generate_commonline_indices. --- src/aspire/abinitio/commonline_d2.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 26258de789..f6d6766a52 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1829,7 +1829,7 @@ def _generate_inplane_rots(sphere_grid, d_theta): # Generate in-plane rotations. d_theta *= np.pi / 180 - # TODO: Negative signs to match matlab. + # Negative signs to match matlab. inplane_rots = Rotation.about_axis( "z", np.arange(0, -2 * np.pi, -d_theta), dtype=dtype ).matrices @@ -1844,8 +1844,12 @@ def _generate_inplane_rots(sphere_grid, d_theta): def _generate_commonline_indices(self, cl_angles): """ - Converts pairs pf commonline angles in [0, 360) first into polar Fourier - indices in [0, self.n_theta), then into commonline linear indices. + Converts a multi-dimensional stack of pairs of commonline angles in [0, 360) degrees + into a flattened stack of polar Fourier linear indices, with the convention that + each linear index corresponds to an unraveled index in [0, n_theta // 2) x [0, n_theta). + + :param cl_angles: A multi-dimensional stack of commonline angles in degrees, shape (..., 2). + :return: cl_idx, a 1D array of linear indices. """ L = self.n_theta @@ -1865,8 +1869,4 @@ def _generate_commonline_indices(self, cl_angles): # Convert to linear indices in 180x360 correlation matrix. cl_idx = np.ravel_multi_index((row_sub, col_sub), dims=(L // 2, L)) - # Return cl_angles in original shape as integer indices. - cl_angles = np.rint(cl_angles.reshape(og_shape)).astype(int) - - # Return as integer indices. return cl_idx From 1ff770f962ff8766a493b830d234dfec1ec7d978 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 14 Aug 2024 10:50:51 -0400 Subject: [PATCH 075/105] doubles for expensive testing --- tests/test_orient_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 6e8f2741ee..c319be4acc 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -16,7 +16,7 @@ # Parameters # ############## -DTYPE = [np.float64, pytest.param(np.float32, marks=pytest.mark.expensive)] +DTYPE = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] RESOLUTION = [48, 49] N_IMG = [10] OFFSETS = [0, pytest.param(None, marks=pytest.mark.expensive)] From c679c0c09ae773a9bea1f206f8e6031efed3f520 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 14 Aug 2024 10:58:07 -0400 Subject: [PATCH 076/105] Add detail to test comment about candidate rotation parameters. --- tests/test_orient_d2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index c319be4acc..e83fcd37c7 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -23,6 +23,10 @@ # Since these tests are optimized for runtime, detuned parameters cause # the algorithm to be fickle, especially for small problem sizes. +# In particular, the parameters `grid_res`, inplane_res`, and `eq_min_dist` +# which control the number of candidate rotations used in the D2 algorithm +# will produce bad estimates if the candidates do not align closely with the +# ground truth rotations. # This seed is chosen so the tests pass CI on github's envs as well # as our self-hosted runner. SEED = 3 From 011d21800942d467d4e4230c956a21bd38bd1887 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 14 Aug 2024 11:39:57 -0400 Subject: [PATCH 077/105] cache source in test --- tests/test_orient_d2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index e83fcd37c7..dee2f25d8e 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -71,6 +71,7 @@ def source(n_img, resolution, dtype, offsets): amplitudes=1, seed=SEED, ) + src = src.cache() # Precompute image stack return src From 42677ff26947b2389d1ac1d4c079f3138a02415f Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 14 Aug 2024 12:04:55 -0400 Subject: [PATCH 078/105] Use randomly ordered Rijs for J-sync tests --- tests/test_orient_d2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index dee2f25d8e..4d79d37b0f 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -178,9 +178,9 @@ def test_global_J_sync(orient_est): rots = orient_est.src.rotations Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) for p, (i, j) in enumerate(orient_est.pairs): - for k, g in enumerate(orient_est.gs): - k = (k + p) % 4 # Mix up the ordering of Rijs - Rijs[p, k] = rots[i].T @ g @ rots[j] + Rij = rots[i].T @ orient_est.gs @ rots[j] + np.random.shuffle(Rij) # Mix up the ordering of Rijs + Rijs[p] = Rij # J-conjugate a random set of Rijs. Rijs_conj = Rijs.copy() @@ -245,9 +245,9 @@ def test_global_J_sync_single_triplet(dtype): rots = orient_est.src.rotations Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) for p, (i, j) in enumerate(orient_est.pairs): - for k, g in enumerate(orient_est.gs): - k = (k + p) % 4 # Mix up the ordering of Rijs - Rijs[p, k] = rots[i].T @ g @ rots[j] + Rij = rots[i].T @ orient_est.gs @ rots[j] + np.random.shuffle(Rij) # Mix up the ordering of Rijs + Rijs[p] = Rij # J-conjugate a random Rij. Rijs_conj = Rijs.copy() From 0f025af248f8a310377d93e5485e33bf75dd9a06 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 14 Aug 2024 16:04:47 -0400 Subject: [PATCH 079/105] randomize color sync test --- tests/test_orient_d2.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 4d79d37b0f..7e997fa8b2 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -268,10 +268,20 @@ def test_sync_colors(orient_est): # Grab set of rotations and generate a set of relative rotations, Rijs. rots = orient_est.src.rotations Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) + gt_colors = np.zeros((orient_est.n_pairs, 3), dtype=int) for p, (i, j) in enumerate(orient_est.pairs): - for k, g in enumerate(orient_est.gs): - k = (k + p) % 4 # Mix up the ordering of Rijs - Rijs[p, k] = rots[i].T @ g @ rots[j] + gs = orient_est.gs + if p > 0: + np.random.shuffle(gs) # Mix up the ordering of all but 1st Rijs + + # Compute the rotation row permutation created by the ordering of gs. + # See Proposition 5.1 in the related publication for details. + for m in range(3): + gt_colors[p, m] = np.argmax(np.sum(abs(0.5 * (gs[0] + gs[m + 1])), axis=0)) + + # Compute Rijs with shuffled gs. + Rij = rots[i].T @ gs @ rots[j] + Rijs[p] = Rij # Perform color synchronization. # Rijs_rows is shape (n_pairs, 3, 3, 3) where Rijs_rows[ij, m] corresponds @@ -284,7 +294,8 @@ def test_sync_colors(orient_est): vijs = np.zeros((orient_est.n_pairs, 3, 3, 3), dtype=orient_est.dtype) for p, (i, j) in enumerate(orient_est.pairs): for m in range(3): - vijs[p, m] = np.outer(rots[i][m], rots[j][m]) + row = gt_colors[p, m] + vijs[p, m] = np.outer(rots[i][row], rots[j][row]) # Reshape `colors` to shape (n_pairs, 3) and use to index Rijs_rows into the # correctly order 3rd row outer products vijs. @@ -303,17 +314,14 @@ def test_sync_colors(orient_est): # Apply this mapping to all rows of the colors array colors_mapped = mapping[colors] - # Synchronize Rijs_rows according to the color map. - row_indices = np.arange(orient_est.n_pairs)[:, None] - Rijs_rows_synced = Rijs_rows[row_indices, colors_mapped] + # Check that remapped color permutations match ground truth. + np.testing.assert_allclose(colors_mapped, gt_colors) # Rijs_rows_synced should match the ground truth vijs up to the sign of each row. # So we multiply by the sign of the first column of the last two axes to sync signs. vijs = vijs * np.sign(vijs[..., 0])[..., None] - Rijs_rows_synced = Rijs_rows_synced * np.sign(Rijs_rows_synced[..., 0])[..., None] - np.testing.assert_allclose( - vijs, Rijs_rows_synced, atol=utest_tolerance(orient_est.dtype) - ) + Rijs_rows = Rijs_rows * np.sign(Rijs_rows[..., 0])[..., None] + np.testing.assert_allclose(vijs, Rijs_rows, atol=utest_tolerance(orient_est.dtype)) # Check dtype pass-through. assert Rijs_rows.dtype == orient_est.dtype From 0551654e495066873a8c90254dbbe6cbd9844803 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 15 Aug 2024 09:07:36 -0400 Subject: [PATCH 080/105] Add documentation to color_sync test. --- tests/test_orient_d2.py | 58 +++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 7e997fa8b2..855d80cfa2 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -265,6 +265,24 @@ def test_global_J_sync_single_triplet(dtype): def test_sync_colors(orient_est): + """ + A set of estimated relative rotations, Rijs, have the shape (n_pairs, 4, 3, 3), + where each 4-tuple Rij is given by Rij = Ri.T @ g_m @ Rj, for m in [0, 1, 2, 3], + where each g_m is an element of the D2 symmetry group. The ordering of the symmetry + group elements, g_m, is unknown and independent between Rijs. The `_sync_colors` + algorithm forms the set of vijs of shape (n_pairs, 3, 3, 3), where each vij, given + by vij = (Rij[0] + Rij[m]) / 2 with m = 1, 2, 3, is some permutation of the outer + products of the k'th rows of the rotation matrices Ri and Rj, for k = 0, 1, 2. + + The 'sync_colors` algorithm uses a colored graph to partition the set of vijs + based on k'th row outer products and returns those outer products along with + a color mapping encoding a permutation for each vij. + + In this test we form a set of Rijs with randomly ordered symmetry group elements + and extract the ground truth color permutations based on that ordering. We then + construct a set of ground truth vijs adjusted by the ground truth color permuations. + We then compare estimated vijs and color permutations to ground truth. + """ # Grab set of rotations and generate a set of relative rotations, Rijs. rots = orient_est.src.rotations Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) @@ -283,13 +301,6 @@ def test_sync_colors(orient_est): Rij = rots[i].T @ gs @ rots[j] Rijs[p] = Rij - # Perform color synchronization. - # Rijs_rows is shape (n_pairs, 3, 3, 3) where Rijs_rows[ij, m] corresponds - # to the outer product vij_m = rots[i, m].T @ rots[j, m] where m is the m'th row - # of the rotations matrices Ri and Rj. `colors` partitions the set of Rijs_rows - # such that the indices of `colors` corresponds to the row index m. - colors, Rijs_rows = orient_est._sync_colors(Rijs) - # Compute ground truth m'th row outer products. vijs = np.zeros((orient_est.n_pairs, 3, 3, 3), dtype=orient_est.dtype) for p, (i, j) in enumerate(orient_est.pairs): @@ -297,34 +308,41 @@ def test_sync_colors(orient_est): row = gt_colors[p, m] vijs[p, m] = np.outer(rots[i][row], rots[j][row]) - # Reshape `colors` to shape (n_pairs, 3) and use to index Rijs_rows into the + # Perform color synchronization. + # `est_vijs` is shape (n_pairs, 3, 3, 3) where est_vijs[ij, m] corresponds + # to the outer product vij_m = rots[i, m].T @ rots[j, m] where m is the m'th row + # of the rotations matrices Ri and Rj. `est_colors` partitions the set of `est_vijs` + # such that the indices of `est_colors` corresponds to the row index m. + est_colors, est_vijs = orient_est._sync_colors(Rijs) + + # Reshape `est_colors` to shape (n_pairs, 3) and use to index est_vijs into the # correctly order 3rd row outer products vijs. - colors = colors.reshape(orient_est.n_pairs, 3) + est_colors = est_colors.reshape(orient_est.n_pairs, 3) - # `colors` is an arbitrary permutation (but globally consistent), and we know - # that colors[0] should correspond to the ordering [0, 1, 2] due to the construction + # `est_colors` is an arbitrary permutation (but globally consistent), and we know + # that est_colors[0] should correspond to the ordering [0, 1, 2] due to the construction # of Rijs[0] using the symmetric rotations g0, g1, g2, g3 in non-permuted order. - # So we sort the columns such that colors[0] = [0,1,2]. + # So we sort the columns such that est_colors[0] = [0,1,2]. # Create a mapping array - perm = colors[0] + perm = est_colors[0] mapping = np.zeros_like(perm) mapping[perm] = np.arange(3) - # Apply this mapping to all rows of the colors array - colors_mapped = mapping[colors] + # Apply this mapping to all rows of the est_colors array + est_colors_mapped = mapping[est_colors] # Check that remapped color permutations match ground truth. - np.testing.assert_allclose(colors_mapped, gt_colors) + np.testing.assert_allclose(est_colors_mapped, gt_colors) - # Rijs_rows_synced should match the ground truth vijs up to the sign of each row. + # est_vijs_synced should match the ground truth vijs up to the sign of each row. # So we multiply by the sign of the first column of the last two axes to sync signs. vijs = vijs * np.sign(vijs[..., 0])[..., None] - Rijs_rows = Rijs_rows * np.sign(Rijs_rows[..., 0])[..., None] - np.testing.assert_allclose(vijs, Rijs_rows, atol=utest_tolerance(orient_est.dtype)) + est_vijs = est_vijs * np.sign(est_vijs[..., 0])[..., None] + np.testing.assert_allclose(vijs, est_vijs, atol=utest_tolerance(orient_est.dtype)) # Check dtype pass-through. - assert Rijs_rows.dtype == orient_est.dtype + assert est_vijs.dtype == orient_est.dtype def test_sync_signs(orient_est): From 4f05efaecef9ae7152a0ce4671d34af34976a201 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 26 Aug 2024 10:17:25 -0400 Subject: [PATCH 081/105] remove F-order flatten --- src/aspire/abinitio/commonline_d2.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index f6d6766a52..231cd7a64d 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -531,7 +531,7 @@ def _generate_scl_scores_idx_map(self): # First the map for i Date: Mon, 26 Aug 2024 14:36:27 -0400 Subject: [PATCH 082/105] Reword docstring --- src/aspire/abinitio/commonline_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 231cd7a64d..66c6debadb 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -653,7 +653,7 @@ def _compute_scl_scores(self): def _all_eq_measures(self, corrs): """ - Compute a measure of how much an image from data is close to an equator. + Compute a measure indicating how likely an image is an equator image. :param corrs: Correlation table of shape (n_theta // 2, n_theta). From 750245040001a110bfe649c2156286a89e267c06 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 27 Aug 2024 10:49:20 -0400 Subject: [PATCH 083/105] Add docstrings --- src/aspire/abinitio/commonline_d2.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 66c6debadb..7a3d3b1652 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -580,7 +580,6 @@ def _generate_scl_scores_idx_map(self): tmp2 = oct2_ij_map[:, :, 1].flatten() self.oct2_ij_map = np.column_stack((tmp1, tmp2)) - ############################################## # Compute Self-Commonline Correlation Scores # ############################################## @@ -751,7 +750,7 @@ def _compute_cl_scores(self): desc="Searching for commonlines between pairs of images", total=n_pairs ) - # Vectorize over pairs of images + # For each i'th image compute the correlation with all j'th images, j > i. for i in range(self.n_img - 1): pf_i = self.pf_shifted[i] scores_i = self.scls_scores[i] @@ -766,7 +765,6 @@ def _compute_cl_scores(self): corrs = np.max(corrs, axis=1) # Max over shifts # Take the product over symmetrically induced candidates. Eq. 4.5 in paper. - # Vectorize extraction and processing of correlations prod_corrs = corrs[:, cl_idx[0], cl_idx[1]] prod_corrs = prod_corrs.reshape(n_pf_js, len(prod_corrs[0]) // 4, 4) prod_corrs = np.prod(prod_corrs, axis=2) @@ -815,6 +813,16 @@ def _get_Rijs_from_lin_idx(self, lin_idx): return Rijs_est def _get_Rijs_from_oct(self, lin_idx, octant=1): + """ + Calculate estimated relative rotations Rijs from the linear indices of + common-lines estimates from the search table. Rijs are generated from the + rotation grids from which the common-lines table was generated. + + :param lin_idx: Set of linear indices corresponding to best estimate of Rijs. + :param octant: Octant of rotation grid from which the Rj rotation was selected + when generating the common-lines table. + :return: Estimated Rijs. + """ if octant not in [1, 2]: raise ValueError("`octant` must be 1 or 2.") @@ -1128,6 +1136,16 @@ def _sync_colors(self, Rijs): return cp, Rijs_rows def _match_colors(self, Rijs_rows): + """ + Partition the set of matrices Rijs_rows, which correspond to a permutation of + the outer products of the m'th rows of Ri and Rj, into 3 sets of matrices each + corresponding to an m'th row. Returns the permutations which induce the partition. + + :param Rijs_rows: An n_pairsx3x3x3 array of m'th row outer products for the pairs + Ri, Rj, where Rijs_rows[:, i] is the m'th row outer product of unknown row m. + :return: n_pairs length array corresponding to the permutation which color matches + Rijs_rows. + """ Rijs_rows_t = np.transpose(Rijs_rows, (0, 1, 3, 2)) trip_perms = np.array( [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 0, 1], [2, 1, 0]], From 7e1d66ebaff107324eef4ba58f39d5579e93621c Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 5 Sep 2024 11:00:15 -0400 Subject: [PATCH 084/105] Use number rays in 2 degrees instead 2 pf rays in all_eq_measures. --- src/aspire/abinitio/commonline_d2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 7a3d3b1652..099827b29d 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -698,7 +698,9 @@ def _all_eq_measures(self, corrs): # between one Fourier ray of the normal to a self common line candidate t_i # with its anti-podal as an additional way to measure if the image is an # equator and t_i+0.5*pi is the normal to its self common line. - r = 2 # Search radius within 2 adjacent rays of normal ray. + r = np.ceil(2 * L / 360).astype( + int + ) # Search radius within 2 degrees of normal ray. # Generate indices for normal to scl index. normal_2_scl_idx_0 = ( From 400c5ec9b835491f20544157cfd492b9e4af6daa Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 5 Sep 2024 11:03:13 -0400 Subject: [PATCH 085/105] Add comment about necessary ordering of D2 symmetry group elements. --- src/aspire/abinitio/commonline_d2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 099827b29d..1cd6190c03 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -80,6 +80,7 @@ def __init__( # D2 symmetry group. # Rearrange in order Identity, about_x, about_y, about_z. + # This ordering is necessary for reproducing MATLAB code results. self.gs = DnSymmetryGroup(order=2, dtype=self.dtype).matrices[[0, 3, 2, 1]] def estimate_rotations(self): From a331001ebc2bca9b407404d4e9333c63467e2f23 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 5 Sep 2024 11:26:14 -0400 Subject: [PATCH 086/105] black --- src/aspire/abinitio/commonline_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 1cd6190c03..0abc0488f8 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -80,7 +80,7 @@ def __init__( # D2 symmetry group. # Rearrange in order Identity, about_x, about_y, about_z. - # This ordering is necessary for reproducing MATLAB code results. + # This ordering is necessary for reproducing MATLAB code results. self.gs = DnSymmetryGroup(order=2, dtype=self.dtype).matrices[[0, 3, 2, 1]] def estimate_rotations(self): From 9181d413b48463532af7ab5b19b23af9676d9070 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 5 Sep 2024 12:09:41 -0400 Subject: [PATCH 087/105] revert all_eq_measures search radius. --- src/aspire/abinitio/commonline_d2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 0abc0488f8..ae820d9790 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -699,9 +699,7 @@ def _all_eq_measures(self, corrs): # between one Fourier ray of the normal to a self common line candidate t_i # with its anti-podal as an additional way to measure if the image is an # equator and t_i+0.5*pi is the normal to its self common line. - r = np.ceil(2 * L / 360).astype( - int - ) # Search radius within 2 degrees of normal ray. + r = 2 # Search radius within 2 adjacent rays of normal ray. # Generate indices for normal to scl index. normal_2_scl_idx_0 = ( From 8cf8bf9cbcadbf7260bf9c5773a82641e0c02602 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 5 Sep 2024 12:11:38 -0400 Subject: [PATCH 088/105] black --- src/aspire/abinitio/commonline_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index ae820d9790..95d5d5db59 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -699,7 +699,7 @@ def _all_eq_measures(self, corrs): # between one Fourier ray of the normal to a self common line candidate t_i # with its anti-podal as an additional way to measure if the image is an # equator and t_i+0.5*pi is the normal to its self common line. - r = 2 # Search radius within 2 adjacent rays of normal ray. + r = 2 # Search radius within 2 adjacent rays of normal ray. # Generate indices for normal to scl index. normal_2_scl_idx_0 = ( From 7aeaf829864a4a4705d288ade422dadc331c154c Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 10 Sep 2024 11:10:49 -0400 Subject: [PATCH 089/105] Add seed to sporadically failing test. --- tests/test_orient_d2.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 855d80cfa2..25b193a629 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -5,6 +5,7 @@ from aspire.source import Simulation from aspire.utils import ( J_conjugate, + Random, Rotation, all_pairs, mean_aligned_angular_distance, @@ -287,19 +288,23 @@ def test_sync_colors(orient_est): rots = orient_est.src.rotations Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) gt_colors = np.zeros((orient_est.n_pairs, 3), dtype=int) - for p, (i, j) in enumerate(orient_est.pairs): - gs = orient_est.gs - if p > 0: - np.random.shuffle(gs) # Mix up the ordering of all but 1st Rijs - - # Compute the rotation row permutation created by the ordering of gs. - # See Proposition 5.1 in the related publication for details. - for m in range(3): - gt_colors[p, m] = np.argmax(np.sum(abs(0.5 * (gs[0] + gs[m + 1])), axis=0)) - # Compute Rijs with shuffled gs. - Rij = rots[i].T @ gs @ rots[j] - Rijs[p] = Rij + with Random(1234): + for p, (i, j) in enumerate(orient_est.pairs): + gs = orient_est.gs + if p > 0: + np.random.shuffle(gs) # Mix up the ordering of all but 1st Rijs. + + # Compute the rotation row permutation created by the ordering of gs. + # See Proposition 5.1 in the related publication for details. + for m in range(3): + gt_colors[p, m] = np.argmax( + np.sum(abs(0.5 * (gs[0] + gs[m + 1])), axis=0) + ) + + # Compute Rijs with shuffled gs. + Rij = rots[i].T @ gs @ rots[j] + Rijs[p] = Rij # Compute ground truth m'th row outer products. vijs = np.zeros((orient_est.n_pairs, 3, 3, 3), dtype=orient_est.dtype) From e61695a05198198afbe43b2e5a558128a2ceb6b6 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 10 Sep 2024 11:33:05 -0400 Subject: [PATCH 090/105] revert search radius --- src/aspire/abinitio/commonline_d2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 95d5d5db59..0abc0488f8 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -699,7 +699,9 @@ def _all_eq_measures(self, corrs): # between one Fourier ray of the normal to a self common line candidate t_i # with its anti-podal as an additional way to measure if the image is an # equator and t_i+0.5*pi is the normal to its self common line. - r = 2 # Search radius within 2 adjacent rays of normal ray. + r = np.ceil(2 * L / 360).astype( + int + ) # Search radius within 2 degrees of normal ray. # Generate indices for normal to scl index. normal_2_scl_idx_0 = ( From a5bc37d4de559cc3b3c558fb6b45acafeab08a29 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 13 Sep 2024 09:22:55 -0400 Subject: [PATCH 091/105] typo --- src/aspire/abinitio/commonline_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 0abc0488f8..f7e0a86e87 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -49,7 +49,7 @@ def __init__( :param grid_res: Number of sampling points on sphere for projetion directions. These are generated using the Saaf-Kuijlaars algorithm. Default value is 1200. :param inplane_res: The sampling resolution of in-plane rotations for each - projetion direction. Default value is 5 degrees. + projection direction. Default value is 5 degrees. :param eq_min_dist: Width of strip around equator projection directions from which we do not sample directions. Default value is 7 degrees. :param epsilon: Tolerance for J-synchronization power method. From 7bab705f8d324eb5ee1bfe4b0491fb83a4346071 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 13 Sep 2024 11:23:03 -0400 Subject: [PATCH 092/105] Use np.nonzero instead of np.where. --- src/aspire/abinitio/commonline_d2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index f7e0a86e87..4c43bdcff5 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -272,7 +272,7 @@ def _generate_commonline_angles( cl_angles = np.zeros((2, n_pairs, n_theta, n_theta // 2, 4, 2)) for i in range(n_rots_i): - unique_pairs_i = np.where(eq2eq_Rij_table[i])[0] + unique_pairs_i = np.nonzero(eq2eq_Rij_table[i])[0] if len(unique_pairs_i) == 0: continue Ri = Ris[i] @@ -355,8 +355,8 @@ def _generate_scl_lookup_data(self): non_eq_idx[:, 0] = ( np.hstack( ( - np.where(self.eq_class1 == 0)[0], - len(self.eq_class1) + np.where(self.eq_class2 == 0)[0], + np.nonzero(self.eq_class1 == 0)[0], + len(self.eq_class1) + np.nonzero(self.eq_class2 == 0)[0], ) ) * self.n_inplane_rots @@ -368,8 +368,8 @@ def _generate_scl_lookup_data(self): # Non-topview equator indices. self.non_tv_eq_idx = np.concatenate( ( - np.where(self.eq_class1 > 0)[0], - len(self.eq_class1) + np.where(self.eq_class2 > 0)[0], + np.nonzero(self.eq_class1 > 0)[0], + len(self.eq_class1) + np.nonzero(self.eq_class2 > 0)[0], ) ) @@ -485,7 +485,7 @@ def _generate_scl_indices(self, scl_angles, eq_class): # As indicated above for equator candidate, for each self common line we # don't get a single coordinate but a range of them. Here we register a # list of coordinates for each such self common line candidate. - non_top_view_eq_idx = np.where(eq_class > 0)[0] + non_top_view_eq_idx = np.nonzero(eq_class > 0)[0] n_eq = len(non_top_view_eq_idx) n_inplane_rots = scl_angles.shape[1] count_eq = 0 From fcf92fa68bed03cff7dcfde287cff417e45356a5 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 13 Sep 2024 15:23:08 -0400 Subject: [PATCH 093/105] lowercase --- tests/test_orient_d2.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 25b193a629..e9dbae5b0b 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -79,7 +79,7 @@ def source(n_img, resolution, dtype, offsets): @pytest.fixture(scope="module") def orient_est(source): - return build_CL_from_source(source) + return build_cl_from_source(source) ######### @@ -140,22 +140,22 @@ def test_scl_scores(orient_est): ) # Initialize CL instance with new source. - CL = build_CL_from_source(src) + cl = build_cl_from_source(src) # Generate lookup data. - CL._compute_shifted_pf() - CL._generate_lookup_data() - CL._generate_scl_lookup_data() + cl._compute_shifted_pf() + cl._generate_lookup_data() + cl._generate_scl_lookup_data() # Compute self-commonline scores. - CL._compute_scl_scores() + cl._compute_scl_scores() - # CL.scls_scores is shape (n_img, n_cand_rots). Since we used the first + # cl.scls_scores is shape (n_img, n_cand_rots). Since we used the first # 10 candidate rotations of the first non-equator viewing direction as our # Simulation rotations, the maximum correlation for image i should occur at - # candidate rotation index (non_eq_idx * CL.n_inplane_rots + i). - max_corr_idx = np.argmax(CL.scls_scores, axis=1) - gt_idx = non_eq_idx * CL.n_inplane_rots + np.arange(10) + # candidate rotation index (non_eq_idx * cl.n_inplane_rots + i). + max_corr_idx = np.argmax(cl.scls_scores, axis=1) + gt_idx = non_eq_idx * cl.n_inplane_rots + np.arange(10) # Check that self-commonline indices match ground truth. n_match = np.sum(max_corr_idx == gt_idx) @@ -165,7 +165,7 @@ def test_scl_scores(orient_est): np.testing.assert_array_less(match_tol, n_match / src.n) # Check dtype pass-through. - assert CL.scls_scores.dtype == orient_est.dtype + assert cl.scls_scores.dtype == orient_est.dtype def test_global_J_sync(orient_est): @@ -240,7 +240,7 @@ def test_global_J_sync_single_triplet(dtype): """ # Generate 3 image source and orientation object. src = Simulation(n=3, L=10, dtype=dtype, seed=SEED) - orient_est = build_CL_from_source(src) + orient_est = build_cl_from_source(src) # Grab set of rotations and generate a set of relative rotations, Rijs. rots = orient_est.src.rotations @@ -455,7 +455,7 @@ def g_sync_d2(rots, rots_gt): return rots_gt_sync -def build_CL_from_source(source): +def build_cl_from_source(source): # Search for common lines over less shifts for 0 offsets. max_shift = 0 shift_step = 1 From b3c6bedf7cb5ef04bd189997481d471d54567841 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 13 Sep 2024 15:25:15 -0400 Subject: [PATCH 094/105] use ints --- src/aspire/abinitio/commonline_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 4c43bdcff5..2f0f29f24a 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1167,7 +1167,7 @@ def _match_colors(self, Rijs_rows): ) m = np.zeros((6, 6), dtype=self.dtype) - colors_i = np.zeros((len(self.triplets), 3), dtype=self.dtype) # ints? + colors_i = np.zeros((len(self.triplets), 3), dtype=int) n_trip = len(self.triplets) votes = np.zeros((n_trip)) trip_idx = 0 From b20e1e032da3cea26ef22c6dad67f004de386aa9 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 16 Sep 2024 11:24:45 -0400 Subject: [PATCH 095/105] input/output docs --- src/aspire/abinitio/commonline_d2.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 2f0f29f24a..aa2ec66a5c 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1097,6 +1097,15 @@ def _sync_colors(self, Rijs): The color sync procedure partitions the set of 3-tuples of m'th row outer products into 3 sets of row-consistent outer products up to the sign of each. + + :param Rijs: Array of shape (n_pairs,4,3,3) consisting of the n_pairs of + hand-consistent 4-tuples of Rijs. + :returns: + - cp, A color mapping vector of length (n_pairs * 3) which permutes + the 3-tuples of `Rijs_rows` to be globally row-consistent. + - Rijs_rows, An array of color synchronized rotations' rows outer products of + shape (n_pairs, 3, 3, 3), where each Rijs_rows[ij] corresponds to a 3-tuple + of m'th row outer product matrices, some of which having a spurious -1. """ logger.info("Performing rotations' rows synchronization.") # Generate array of one rank matrices from which we can extract rows. From e8c2221ba79f26532f72907b90042d4ce9071e36 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 17 Sep 2024 14:34:36 -0400 Subject: [PATCH 096/105] break up _sync_signs --- src/aspire/abinitio/commonline_d2.py | 121 ++++++++++++++++++++------- 1 file changed, 89 insertions(+), 32 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index aa2ec66a5c..29a45ff3ec 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1299,10 +1299,10 @@ def _mult_cmat_by_vec(self, c_perms, v): trip_idx = 0 for i in range(self.n_img): for j in range(i + 1, self.n_img - 1): - ij = 3 * self.pairs_to_linear[i, j] + ij_block = 3 * self.pairs_to_linear[i, j] for k in range(j + 1, self.n_img): - ik = 3 * self.pairs_to_linear[i, k] - jk = 3 * self.pairs_to_linear[j, k] + ik_block = 3 * self.pairs_to_linear[i, k] + jk_block = 3 * self.pairs_to_linear[j, k] # Extract permutation indices from c_perms n = c_perms[trip_idx] @@ -1319,36 +1319,36 @@ def _mult_cmat_by_vec(self, c_perms, v): # Multiply vector by color matrix # Upper triangular part - p = t_perms[p_n1] + ik - out[ij] = out[ij] - v[p[1]] - v[p[2]] + v[p[0]] - out[ij + 1] = out[ij + 1] - v[p[0]] - v[p[2]] + v[p[1]] - out[ij + 2] = out[ij + 2] - v[p[0]] - v[p[1]] + v[p[2]] + p = t_perms[p_n1] + ik_block + out[ij_block] = out[ij_block] - v[p[1]] - v[p[2]] + v[p[0]] + out[ij_block + 1] = out[ij_block + 1] - v[p[0]] - v[p[2]] + v[p[1]] + out[ij_block + 2] = out[ij_block + 2] - v[p[0]] - v[p[1]] + v[p[2]] - p = t_perms[p_n2] + jk - out[ij] = out[ij] - v[p[1]] - v[p[2]] + v[p[0]] - out[ij + 1] = out[ij + 1] - v[p[0]] - v[p[2]] + v[p[1]] - out[ij + 2] = out[ij + 2] - v[p[0]] - v[p[1]] + v[p[2]] + p = t_perms[p_n2] + jk_block + out[ij_block] = out[ij_block] - v[p[1]] - v[p[2]] + v[p[0]] + out[ij_block + 1] = out[ij_block + 1] - v[p[0]] - v[p[2]] + v[p[1]] + out[ij_block + 2] = out[ij_block + 2] - v[p[0]] - v[p[1]] + v[p[2]] - p = i_perms[p_n3] + jk - out[ik] = out[ik] - v[p[1]] - v[p[2]] + v[p[0]] - out[ik + 1] = out[ik + 1] - v[p[0]] - v[p[2]] + v[p[1]] - out[ik + 2] = out[ik + 2] - v[p[0]] - v[p[1]] + v[p[2]] + p = i_perms[p_n3] + jk_block + out[ik_block] = out[ik_block] - v[p[1]] - v[p[2]] + v[p[0]] + out[ik_block + 1] = out[ik_block + 1] - v[p[0]] - v[p[2]] + v[p[1]] + out[ik_block + 2] = out[ik_block + 2] - v[p[0]] - v[p[1]] + v[p[2]] # Lower triangular part - p = i_perms[p_n1] + ij - out[ik] = out[ik] - v[p[1]] - v[p[2]] + v[p[0]] - out[ik + 1] = out[ik + 1] - v[p[0]] - v[p[2]] + v[p[1]] - out[ik + 2] = out[ik + 2] - v[p[0]] - v[p[1]] + v[p[2]] - - p = i_perms[p_n2] + ij - out[jk] = out[jk] - v[p[1]] - v[p[2]] + v[p[0]] - out[jk + 1] = out[jk + 1] - v[p[0]] - v[p[2]] + v[p[1]] - out[jk + 2] = out[jk + 2] - v[p[0]] - v[p[1]] + v[p[2]] - - p = t_perms[p_n3] + ik - out[jk] = out[jk] - v[p[1]] - v[p[2]] + v[p[0]] - out[jk + 1] = out[jk + 1] - v[p[0]] - v[p[2]] + v[p[1]] - out[jk + 2] = out[jk + 2] - v[p[0]] - v[p[1]] + v[p[2]] + p = i_perms[p_n1] + ij_block + out[ik_block] = out[ik_block] - v[p[1]] - v[p[2]] + v[p[0]] + out[ik_block + 1] = out[ik_block + 1] - v[p[0]] - v[p[2]] + v[p[1]] + out[ik_block + 2] = out[ik_block + 2] - v[p[0]] - v[p[1]] + v[p[2]] + + p = i_perms[p_n2] + ij_block + out[jk_block] = out[jk_block] - v[p[1]] - v[p[2]] + v[p[0]] + out[jk_block + 1] = out[jk_block + 1] - v[p[0]] - v[p[2]] + v[p[1]] + out[jk_block + 2] = out[jk_block + 2] - v[p[0]] - v[p[1]] + v[p[2]] + + p = t_perms[p_n3] + ik_block + out[jk_block] = out[jk_block] - v[p[1]] - v[p[2]] + v[p[0]] + out[jk_block + 1] = out[jk_block + 1] - v[p[0]] - v[p[2]] + v[p[1]] + out[jk_block + 2] = out[jk_block + 2] - v[p[0]] - v[p[1]] + v[p[2]] return out def _unmix_colors(self, color_vecs): @@ -1434,8 +1434,14 @@ def R_theta(theta): def _sync_signs(self, rr, c_vec): """ This function executes the final stage of the algorithm, Signs - synchroniztion. At the end all rows of the rotations Ri are exctracted - and the matrices Ri are assembled. + synchroniztion. At this point, we have rotation rows + rr[ij, m] = sij_m * vi_m.T @ vj_m, where vi_m, vj_m are the m'th rows + of rotation matrices Ri and Rj and sij_m is an unknown sign. This method + uses the permutation vector, `c_vec`, to partition the rotation row + outer products and constructs a symmetric block matrix, H, with ij'th block + sij * vi.T @ vj. The signs sij are then adjusted so that H is rank-1. This + matrix is then factored to extract the rows of each rotation matrix. At the + end all rows of the rotations Ri are exctracted and the matrices Ri are assembled. :param rr: Array of color synchronized rotations' rows outer products of shape (n_pairs, 3, 3, 3), where each rr[ij] corresponds to a 3-tuple @@ -1445,6 +1451,28 @@ def _sync_signs(self, rr, c_vec): :return: n_img x 3 x 3 array of rotation matrices. """ logger.info("Performing signs synchronization.") + c_mat, c_mat_5d, c_mat_4d = self._construct_color_mats(rr, c_vec) + + sync_signs2 = self._compute_signs(c_mat_5d, c_mat_4d) + + rows_arr = self._estimate_rows(sync_signs2, c_mat_5d) + + signs = self._compute_signs_adjustment(rows_arr) + + rots = self._extract_rotations(c_mat, signs) + + return rots + + def _construct_color_mats(self, rr, c_vec): + """ + Construct the partitioned row synchronized color matrices, `c_mat`, where + c_mat[m] contains the 3x3 blocks sij*vi_m.T @ vj_m, where vi_m is the m'th + row of the i'th rotation Ri and sij is the unknown sign. + + :param rr: Non-partitioned rotation row matrices. + :param c_vec: Color partition vector. + :return: Partitioned row synchronized color matrices. + """ # Partition the union of tuples {0.5*(Ri^TRj+Ri^TgkRj), k=1:3} according # to the color partition established in color synchronization procedure. # The partition is stored in two different arrays each with the purpose @@ -1494,6 +1522,12 @@ def _sync_signs(self, rr, c_vec): for i in range(self.n_img): c_mat[c, i, :, i, :] = c_mat_5d[i, i, c] + return c_mat, c_mat_5d, c_mat_4d + + def _compute_signs(self, c_mat_5d, c_mat_4d): + """ + Compute signs for adjusting `c_mat` to be composed of all rank-1 3x3 blocks. + """ # To decompose cMat as a rank 1 matrix we need to adjust the signs of the # Qij^c so that sign(Qij^c*Qjk^c) = sign(Qik^c) for all c=1,2,3 and (i,j). # In practice we compare the sign of the sum of the entries of Qij^c*Qjk^c @@ -1552,6 +1586,14 @@ def _sync_signs(self, rr, c_vec): ) # The function (1-x)/2 maps 1->0 and -1->1 + return sync_signs2 + + def _estimate_rows(self, sync_signs2, c_mat_5d): + """ + Construct 3N x 3N matrix of rank-1 3x3 blocks of sij*vi_m.T @ vj_m, + the leading eigenvectors of which correspond to estimates for the rows + of the rotations Ri, up to signs. + """ c_mat_5d_mp = np.concatenate((c_mat_5d, -c_mat_5d), axis=1) rows_arr = np.zeros((3, self.n_img, 3 * self.n_img), dtype=self.dtype) svals = np.zeros((3, 2, self.n_img), dtype=self.dtype) @@ -1586,6 +1628,12 @@ def _sync_signs(self, rr, c_vec): svals[c, :, r] = S[:2] rows_arr[c, r] = U[:, 0] + return rows_arr + + def _compute_signs_adjustment(self, rows_arr): + """ + Compute signs adjustment vector. + """ # Sync signs according to results for each image. Dot products between # signed row estimates are used to construct an (N over 2)x(N over 2) # sign synchronization matrix S. If (v_i)k and (v_j)k are the i'th and @@ -1640,8 +1688,17 @@ def _sync_signs(self, rr, c_vec): signs[c] = U[:, -1] # svds returns in ascending order s_out[c] = S[::-1] - signs = np.sign(signs) + return np.sign(signs) + def _extract_rotations(self, c_mat, signs): + """ + Adjust the signs of each block of `c_mat` then extract the rotation + rows and construct the estimated rotations. + + :param c_mat: The color synchronization matrix. + :param signs: The signs adjustment matrix. + :return: Estimated rotations. + """ # Adjust the signs of Qij^c in the matrices cMat(:,:,c) for all c=1,2,3 # and 1<=i Date: Tue, 17 Sep 2024 15:33:25 -0400 Subject: [PATCH 097/105] Rewrite docstring --- src/aspire/abinitio/commonline_d2.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 29a45ff3ec..db5f799630 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1149,9 +1149,18 @@ def _sync_colors(self, Rijs): def _match_colors(self, Rijs_rows): """ - Partition the set of matrices Rijs_rows, which correspond to a permutation of - the outer products of the m'th rows of Ri and Rj, into 3 sets of matrices each - corresponding to an m'th row. Returns the permutations which induce the partition. + For each triplet of indices i < j < k, we consider the m'th row outer products stored + as Rijs_rows, ie. Rijs_rows[ij], Rijs_rows[jk], and Rijs_rows[ik]. Recall that + Rijs_rows[ij, n], n=0,1,2, corresponds to the 3x3 outer product vi_m.T @ vj_m, where + vi_m is an unknown row of the rotation matrices Ri and Rj. For each triplet of these + sets of row outer products this method finds a permutation sigma such that + Rijs_rows[ij, sigma(n)], Rijs_rows[jk, sigma(n)], and Rijs_rows[ik, sigma(n)] all + correspond to the same m'th row outer product. + + Framed as graph partioning problem we are coloring the vertices, Rijs_rows[ij, n], + with three colors such that each color corresponds to the same row of the rotations + Ris. This method returns the permutation that rearanges the elements of each triplet + of Rijs to have matching color. :param Rijs_rows: An n_pairsx3x3x3 array of m'th row outer products for the pairs Ri, Rj, where Rijs_rows[:, i] is the m'th row outer product of unknown row m. From d4a0547da2b4d4e44f9dd2b8e7a973f0e80e5c0c Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 17 Sep 2024 15:54:28 -0400 Subject: [PATCH 098/105] resolve numpy deprecation warning. --- tests/test_orient_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index e9dbae5b0b..7414729799 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -124,7 +124,7 @@ def test_scl_scores(orient_est): # In this case, we take first 10 candidates from a non-equator viewing direction. orient_est._generate_lookup_data() cand_rots = orient_est.inplane_rotated_grid1 - non_eq_idx = int(np.argwhere(orient_est.eq_class1 == 0)[0]) + non_eq_idx = int(np.argwhere(orient_est.eq_class1 == 0)[0][0]) rots = cand_rots[non_eq_idx, :10] angles = Rotation(rots).angles From 2014b42e355e1b4577e00806a45c0a746c6b980b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 17 Sep 2024 15:56:45 -0400 Subject: [PATCH 099/105] try diff seed --- tests/test_orient_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 7414729799..ca972bcf7c 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -289,7 +289,7 @@ def test_sync_colors(orient_est): Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) gt_colors = np.zeros((orient_est.n_pairs, 3), dtype=int) - with Random(1234): + with Random(123): for p, (i, j) in enumerate(orient_est.pairs): gs = orient_est.gs if p > 0: From fcf1567b3cdf0681ef99e06f097f694cffc850d4 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 18 Sep 2024 10:00:18 -0400 Subject: [PATCH 100/105] switch test to doubles to diagnose osx-arm failures. --- tests/test_orient_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index ca972bcf7c..f3a05821fd 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -17,7 +17,7 @@ # Parameters # ############## -DTYPE = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] +DTYPE = [np.float64, pytest.param(np.float32, marks=pytest.mark.expensive)] RESOLUTION = [48, 49] N_IMG = [10] OFFSETS = [0, pytest.param(None, marks=pytest.mark.expensive)] From 84048e0fc0a51c23952ed39486ee969a8386ffc8 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 18 Sep 2024 13:16:11 -0400 Subject: [PATCH 101/105] Remove eq_idx --- src/aspire/abinitio/commonline_d2.py | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index db5f799630..61fd954325 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -171,8 +171,8 @@ def _generate_lookup_data(self): # We detect such directions by taking a strip of radius # `eq_min_dist` about the 3 great circles perpendicular to the symmetry # axes of D2 (i.e to X,Y and Z axes). - eq_idx1, eq_class1 = self._mark_equators(sphere_grid1, self.eq_min_dist) - eq_idx2, eq_class2 = self._mark_equators(sphere_grid2, self.eq_min_dist) + eq_class1 = self._mark_equators(sphere_grid1, self.eq_min_dist) + eq_class2 = self._mark_equators(sphere_grid2, self.eq_min_dist) # Mark Top View Directions. # A Top view projection image is taken from the direction of one of the @@ -191,9 +191,6 @@ def _generate_lookup_data(self): # Remove top views from sphere grids and update equator indices and classes. self.sphere_grid1 = sphere_grid1[eq_class1 < 4] self.sphere_grid2 = sphere_grid2[eq_class2 < 4] - self.eq_idx1 = eq_idx1[eq_class1 < 4] - self.eq_idx2 = eq_idx2[eq_class2 < 4] - self.eq_idx = np.concatenate((self.eq_idx1, self.eq_idx2)) self.eq_class1 = eq_class1[eq_class1 < 4] self.eq_class2 = eq_class2[eq_class2 < 4] @@ -209,16 +206,12 @@ def _generate_lookup_data(self): cl_angles1, self.eq2eq_Rij_table_11 = self._generate_commonline_angles( self.inplane_rotated_grid1, self.inplane_rotated_grid1, - self.eq_idx1, - self.eq_idx1, self.eq_class1, self.eq_class1, ) cl_angles2, self.eq2eq_Rij_table_12 = self._generate_commonline_angles( self.inplane_rotated_grid1, self.inplane_rotated_grid2, - self.eq_idx1, - self.eq_idx2, self.eq_class1, self.eq_class2, same_octant=False, @@ -233,8 +226,6 @@ def _generate_commonline_angles( self, Ris, Rjs, - Ri_eq_idx, - Rj_eq_idx, Ri_eq_class, Rj_eq_class, same_octant=True, @@ -246,8 +237,6 @@ def _generate_commonline_angles( :param Ris: First set of candidate rotations. :param Rjs: Second set of candidate rotation. - :param Ri_eq_idx: Equator index mask. - :param Rj_eq_idx: Equator index mask. :param Ri_eq_class: Equator classification for Ris. :param Rj_eq_class: Equator classification for Rjs. :param same_octant: True if both sets of candidates are in the same octant. @@ -259,7 +248,7 @@ def _generate_commonline_angles( # Generate upper triangular table of indicators of all pairs which are not # equators with respect to the same symmetry axis (named unique_pairs). - eq_table = np.outer(Ri_eq_idx, Rj_eq_idx) + eq_table = np.outer(Ri_eq_class > 0, Rj_eq_class > 0) in_same_class = (Ri_eq_class[:, None] - Rj_eq_class.T[None]) == 0 eq2eq_Rij_table = ~(eq_table * in_same_class) @@ -320,12 +309,10 @@ def _generate_scl_lookup_data(self): # Get self-commonline angles. self.scl_angles1 = self._generate_scl_angles( self.inplane_rotated_grid1, - self.eq_idx1, self.eq_class1, ) self.scl_angles2 = self._generate_scl_angles( self.inplane_rotated_grid2, - self.eq_idx2, self.eq_class2, ) @@ -376,7 +363,7 @@ def _generate_scl_lookup_data(self): # Generate maps from scl indices to relative rotations. self._generate_scl_scores_idx_map() - def _generate_scl_angles(self, Ris, eq_idx, eq_class): + def _generate_scl_angles(self, Ris, eq_class): """ Generate self-commonline angles. For each candidate rotation a pair of self-commonline angles are generated for each of the 3 self-commonlines induced by D2 symmetry. @@ -1847,9 +1834,7 @@ def _mark_equators(sphere_grid, eq_filter_angle): :param eq_filter_angle: Angular distance from equator to be marked as an equator point. - :returns: - - eq_idx, a boolean mask for equator indices. - - eq_class, n_rots length array of values indicating equator class. + :return: eq_class, n_rots length array of values indicating equator class. """ # Project each vector onto xy, xz, yz planes and measure angular distance # from each plane. @@ -1866,7 +1851,6 @@ def _mark_equators(sphere_grid, eq_filter_angle): # Mark all views close to an equator. eq_min_dist = np.cos(eq_filter_angle * np.pi / 180) n_eqs = np.count_nonzero(angular_dists > eq_min_dist, axis=1) - eq_idx = n_eqs > 0 # Classify equators. # 0 -> non-equator view @@ -1889,7 +1873,7 @@ def _mark_equators(sphere_grid, eq_filter_angle): eq_view_class = np.argmax(angular_dists[eq_view_idx] > eq_min_dist, axis=1) eq_class[eq_view_idx] = eq_view_class + 1 - return eq_idx, eq_class + return eq_class @staticmethod def _generate_inplane_rots(sphere_grid, d_theta): From 357e5d314bbd17d6dcaf9c40a128a18a445e9a13 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 18 Sep 2024 14:45:46 -0400 Subject: [PATCH 102/105] Add descript of eq2eq table to docs. --- src/aspire/abinitio/commonline_d2.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 61fd954325..99bbf9d898 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -233,7 +233,11 @@ def _generate_commonline_angles( """ Compute commonline angles induced by the 4 sets of relative rotations Rij = Ri.T @ g_m @ Rj, m = 0,1,2,3, where g_m is the identity and rotations - about the three axes of symmetry of a D2 symmetric molecule. + about the three axes of symmetry of a D2 symmetric molecule. Note, we only + compute commonline angles between pairs of images which are not equator + images with respect to the same axis of symmetry. To do this we build a + table, `eq2eq_Rij_table`, which is `False` for pairs of images that are + equator images with respect to the same axis of symmetry and `True` otherwise. :param Ris: First set of candidate rotations. :param Rjs: Second set of candidate rotation. From b948d7a48361d31b65b0c15d512d1895295beae2 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 19 Sep 2024 08:58:24 -0400 Subject: [PATCH 103/105] add more documentation. --- src/aspire/abinitio/commonline_d2.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 99bbf9d898..847236854c 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -516,17 +516,33 @@ def _generate_scl_scores_idx_map(self): """ Generates lookup tables for maximum likelihood scheme to estimate commonlines between images. + + This method creates two lookup tables (`oct1_ij_map` and `oct2_ij_map`) + for pairs of candidate rotations (i, j) under the following conditions: + + 1. Both rotations Ri and Rj are in octant 1. + 2. Ri is in octant 1 and Rj is in octant 2. + + For each pair of candidate rotations the tables give a map into the set of + self-commonlines induced by those rotations. This table will be used later + to incorporate a likelihood score for self-commonlines into the likelihood + score for common lines for each pair of images. """ + # Calculate number of rotations in each octant. n_rot_1 = len(self.scl_idx_1) // (3 * self.n_inplane_rots) n_rot_2 = len(self.scl_idx_2) // (3 * self.n_inplane_rots) - # First the map for i Date: Thu, 19 Sep 2024 13:49:45 -0400 Subject: [PATCH 104/105] Always doubles for scipy LinearOperator --- src/aspire/abinitio/commonline_d2.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 847236854c..a8e951c642 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -1146,13 +1146,15 @@ def _sync_colors(self, Rijs): color_mat = la.LinearOperator( (3 * n_pairs,) * 2, lambda v: self._mult_cmat_by_vec(color_perms, v) ) - v0 = randn( - 3 * n_pairs, seed=self.seed - ) # Seed eigs initial vector for iterative method + + # Seed eigs initial vector for iterative method. + # scipy LinearOperator needs doubles for some architectures (arm). + v0 = randn(3 * n_pairs, seed=self.seed).astype(np.float64, copy=False) + v0 = v0 / norm(v0) vals, colors = la.eigs(color_mat, k=3, which="LR", v0=v0) vals = np.real(vals) - colors = np.real(colors) + colors = np.real(colors).astype(self.dtype, copy=False) colors = np.sign(colors[0]) * colors # Stable eigs cp, _ = self._unmix_colors(colors[:, :2]) From ae9b98f18c629fcd176537559ed145851307f2df Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 19 Sep 2024 13:52:40 -0400 Subject: [PATCH 105/105] revert test to singles --- tests/test_orient_d2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index f3a05821fd..ca972bcf7c 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -17,7 +17,7 @@ # Parameters # ############## -DTYPE = [np.float64, pytest.param(np.float32, marks=pytest.mark.expensive)] +DTYPE = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] RESOLUTION = [48, 49] N_IMG = [10] OFFSETS = [0, pytest.param(None, marks=pytest.mark.expensive)]