diff --git a/src/aspire/utils/matrix.py b/src/aspire/utils/matrix.py index d5d495365c..f48f9bbac5 100644 --- a/src/aspire/utils/matrix.py +++ b/src/aspire/utils/matrix.py @@ -336,6 +336,15 @@ def nearest_rotations(A, allow_reflection=False): # R = (U * -1 @ VT) * [-1,-1,1] # R = (U @ VT) * (-1 * [-1,-1,1]) # R = U @ (VT * [1,1,-1]) + # + # See: + # + # Gower, J.C. (1976), Procrustes rotation problems. The + # Mathematical Scientist, 1 (Supplement), 12-15. + # + # Ten Berge, J.M.F. (2006), The Rigid Orthogonal Procrustes + # Rotation Problem. Psychometrika, vol. 71, no., 1, 201-205. + d = np.array([1, 1, -1], dtype=dtype) neg_det_idx = np.linalg.det(U) * np.linalg.det(VT) < 0 VT[neg_det_idx] = VT[neg_det_idx] * d diff --git a/src/aspire/utils/rotation.py b/src/aspire/utils/rotation.py index 2a2e0f98a0..720b9a7898 100644 --- a/src/aspire/utils/rotation.py +++ b/src/aspire/utils/rotation.py @@ -10,6 +10,7 @@ from scipy.spatial.transform import Rotation as sp_rot from aspire.utils.random import Random +from aspire.utils.matrix import nearest_rotations class Rotation: @@ -107,8 +108,8 @@ def find_registration(self, rots_ref): :param rots_ref: The reference Rotation object to which we would like to align with data matrices in the form of a n-by-3-by-3 array. - :return: o_mat, optimal orthogonal 3x3 matrix to align the two sets; - flag, flag==1 then J conjugacy is required and 0 is not. + :return: Q, optimal orthogonal 3x3 matrix to align the two sets; + J_conj, specifies whether J-conjugacy is needed. """ rots = self._matrices rots_ref = rots_ref.matrices.astype(self.dtype) @@ -129,44 +130,34 @@ def find_registration(self, rots_ref): Q1 = Q1 + R @ Rref.T Q2 = Q2 + (J @ R @ J) @ Rref.T - # Compute the two possible orthogonal matrices which register the - # estimated rotations to the true ones. - Q1 = Q1 / K - Q2 = Q2 / K + Q1 = nearest_rotations(Q1) + Q2 = nearest_rotations(Q2) # We are registering one set of rotations (the estimated ones) to # another set of rotations (the true ones). Thus, the transformation # matrix between the two sets of rotations should be orthogonal. This # matrix is either Q1 if we recover the non-reflected solution, or Q2, - # if we got the reflected one. In any case, one of them should be - # orthogonal. + # if we got the reflected one. - err1 = norm(Q1 @ Q1.T - np.eye(3), ord="fro") - err2 = norm(Q2 @ Q2.T - np.eye(3), ord="fro") + err1 = np.sum([norm(Q1.T @ R - Rref, "fro") ** 2 + for R, Rref in zip(rots, rots_ref)]) + err2 = np.sum([norm(Q2.T @ (J @ R @ J) - Rref, "fro") ** 2 + for R, Rref in zip(rots, rots_ref)]) - # In any case, enforce the registering matrix O to be a rotation. if err1 < err2: - # Use Q1 as the registering matrix - U, _, V = svd(Q1) - flag = 0 + return Q1, False else: - # Use Q2 as the registering matrix - U, _, V = svd(Q2) - flag = 1 + return Q2, True - Q_mat = U @ V - - return Q_mat, flag - - def apply_registration(self, Q_mat, flag): + def apply_registration(self, Q, J_conj): """ Get aligned Rotation object to reference ones. Calculated aligned rotation matrices from the orthogonal transformation that best aligns the estimated rotations to the reference rotations. - :param Q_mat: optimal orthogonal 3x3 transformation matrix - :param flag: flag==1 then J conjugacy is required and 0 is not + :param Q: optimal orthogonal 3x3 transformation matrix + :param J_conj: whether J-conjugacy is required :return: regrot, aligned Rotation object """ rots = self._matrices @@ -178,9 +169,9 @@ def apply_registration(self, Q_mat, flag): regrot = np.zeros_like(rots) for k in range(K): R = rots[k, :, :] - if flag == 1: + if J_conj: R = J @ R @ J - regrot[k, :, :] = Q_mat.T @ R + regrot[k, :, :] = Q.T @ R aligned_rots = Rotation(regrot) return aligned_rots @@ -192,8 +183,8 @@ def register(self, rots_ref): to align with data matrices in the form of a n-by-3-by-3 array. :return: an aligned Rotation object """ - Q_mat, flag = self.find_registration(rots_ref) - return self.apply_registration(Q_mat, flag) + Q, J_conj = self.find_registration(rots_ref) + return self.apply_registration(Q, J_conj) def mse(self, rots_ref): """