Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/aspire/utils/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,15 @@ def nearest_rotations(A, allow_reflection=False):
# R = (U * -1 @ VT) * [-1,-1,1]
# R = (U @ VT) * (-1 * [-1,-1,1])
# R = U @ (VT * [1,1,-1])
#
# See:
#
# Gower, J.C. (1976), Procrustes rotation problems. The
# Mathematical Scientist, 1 (Supplement), 12-15.
#
# Ten Berge, J.M.F. (2006), The Rigid Orthogonal Procrustes
# Rotation Problem. Psychometrika, vol. 71, no., 1, 201-205.

d = np.array([1, 1, -1], dtype=dtype)
neg_det_idx = np.linalg.det(U) * np.linalg.det(VT) < 0
VT[neg_det_idx] = VT[neg_det_idx] * d
Expand Down
47 changes: 19 additions & 28 deletions src/aspire/utils/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.spatial.transform import Rotation as sp_rot

from aspire.utils.random import Random
from aspire.utils.matrix import nearest_rotations


class Rotation:
Expand Down Expand Up @@ -107,8 +108,8 @@ def find_registration(self, rots_ref):

:param rots_ref: The reference Rotation object to which we would like to align
with data matrices in the form of a n-by-3-by-3 array.
:return: o_mat, optimal orthogonal 3x3 matrix to align the two sets;
flag, flag==1 then J conjugacy is required and 0 is not.
:return: Q, optimal orthogonal 3x3 matrix to align the two sets;
J_conj, specifies whether J-conjugacy is needed.
"""
rots = self._matrices
rots_ref = rots_ref.matrices.astype(self.dtype)
Expand All @@ -129,44 +130,34 @@ def find_registration(self, rots_ref):
Q1 = Q1 + R @ Rref.T
Q2 = Q2 + (J @ R @ J) @ Rref.T

# Compute the two possible orthogonal matrices which register the
# estimated rotations to the true ones.
Q1 = Q1 / K
Q2 = Q2 / K
Q1 = nearest_rotations(Q1)
Q2 = nearest_rotations(Q2)

# We are registering one set of rotations (the estimated ones) to
# another set of rotations (the true ones). Thus, the transformation
# matrix between the two sets of rotations should be orthogonal. This
# matrix is either Q1 if we recover the non-reflected solution, or Q2,
# if we got the reflected one. In any case, one of them should be
# orthogonal.
# if we got the reflected one.

err1 = norm(Q1 @ Q1.T - np.eye(3), ord="fro")
err2 = norm(Q2 @ Q2.T - np.eye(3), ord="fro")
err1 = np.sum([norm(Q1.T @ R - Rref, "fro") ** 2
for R, Rref in zip(rots, rots_ref)])
err2 = np.sum([norm(Q2.T @ (J @ R @ J) - Rref, "fro") ** 2
for R, Rref in zip(rots, rots_ref)])

# In any case, enforce the registering matrix O to be a rotation.
if err1 < err2:
# Use Q1 as the registering matrix
U, _, V = svd(Q1)
flag = 0
return Q1, False
else:
# Use Q2 as the registering matrix
U, _, V = svd(Q2)
flag = 1
return Q2, True

Q_mat = U @ V

return Q_mat, flag

def apply_registration(self, Q_mat, flag):
def apply_registration(self, Q, J_conj):
"""
Get aligned Rotation object to reference ones.

Calculated aligned rotation matrices from the orthogonal transformation
that best aligns the estimated rotations to the reference rotations.

:param Q_mat: optimal orthogonal 3x3 transformation matrix
:param flag: flag==1 then J conjugacy is required and 0 is not
:param Q: optimal orthogonal 3x3 transformation matrix
:param J_conj: whether J-conjugacy is required
:return: regrot, aligned Rotation object
"""
rots = self._matrices
Expand All @@ -178,9 +169,9 @@ def apply_registration(self, Q_mat, flag):
regrot = np.zeros_like(rots)
for k in range(K):
R = rots[k, :, :]
if flag == 1:
if J_conj:
R = J @ R @ J
regrot[k, :, :] = Q_mat.T @ R
regrot[k, :, :] = Q.T @ R
aligned_rots = Rotation(regrot)
return aligned_rots

Expand All @@ -192,8 +183,8 @@ def register(self, rots_ref):
to align with data matrices in the form of a n-by-3-by-3 array.
:return: an aligned Rotation object
"""
Q_mat, flag = self.find_registration(rots_ref)
return self.apply_registration(Q_mat, flag)
Q, J_conj = self.find_registration(rots_ref)
return self.apply_registration(Q, J_conj)

def mse(self, rots_ref):
"""
Expand Down
Loading