Skip to content

Commit

Permalink
Merge pull request #14371 from cmarmo/is-o3-atol
Browse files Browse the repository at this point in the history
Add ``atol`` argument to function ``is_O3``
  • Loading branch information
mhvk committed Feb 9, 2023
2 parents c6e486d + 20ab22e commit 24a4ee1
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
34 changes: 28 additions & 6 deletions astropy/coordinates/matrix_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,19 @@ def angle_axis(matrix):
return Angle(angle, u.radian), -axis / r


def is_O3(matrix):
def is_O3(matrix, atol=None):
"""Check whether a matrix is in the length-preserving group O(3).
Parameters
----------
matrix : (..., N, N) array-like
Must have attribute ``.shape`` and method ``.swapaxes()`` and not error
when using `~numpy.isclose`.
atol : float, optional
The allowed absolute difference.
If `None` it defaults to 1e-15 or 5 * epsilon of the matrix's dtype, if floating.
.. versionadded:: 5.3
Returns
-------
Expand All @@ -159,14 +164,20 @@ def is_O3(matrix):
"""
# matrix is in O(3) (rotations, proper and improper).
I = np.identity(matrix.shape[-1])
if atol is None:
if np.issubdtype(matrix.dtype, np.floating):
atol = np.finfo(matrix.dtype).eps * 5
else:
atol = 1e-15

is_o3 = np.all(
np.isclose(matrix @ matrix.swapaxes(-2, -1), I, atol=1e-15), axis=(-2, -1)
np.isclose(matrix @ matrix.swapaxes(-2, -1), I, atol=atol), axis=(-2, -1)
)

return is_o3


def is_rotation(matrix, allow_improper=False):
def is_rotation(matrix, allow_improper=False, atol=None):
"""Check whether a matrix is a rotation, proper or improper.
Parameters
Expand All @@ -178,6 +189,11 @@ def is_rotation(matrix, allow_improper=False):
Whether to restrict check to the SO(3), the group of proper rotations,
or also allow improper rotations (with determinant -1).
The default (False) is only SO(3).
atol : float, optional
The allowed absolute difference.
If `None` it defaults to 1e-15 or 5 * epsilon of the matrix's dtype, if floating.
.. versionadded:: 5.3
Returns
-------
Expand All @@ -198,13 +214,19 @@ def is_rotation(matrix, allow_improper=False):
For more information, see https://en.wikipedia.org/wiki/Orthogonal_group
"""
if atol is None:
if np.issubdtype(matrix.dtype, np.floating):
atol = np.finfo(matrix.dtype).eps * 5
else:
atol = 1e-15

# matrix is in O(3).
is_o3 = is_O3(matrix)
is_o3 = is_O3(matrix, atol=atol)

# determinant checks for rotation (proper and improper)
if allow_improper: # determinant can be +/- 1
is_det1 = np.isclose(np.abs(np.linalg.det(matrix)), 1.0)
is_det1 = np.isclose(np.abs(np.linalg.det(matrix)), 1.0, atol=atol)
else: # restrict to SO(3)
is_det1 = np.isclose(np.linalg.det(matrix), 1.0)
is_det1 = np.isclose(np.linalg.det(matrix), 1.0, atol=atol)

return is_o3 & is_det1
8 changes: 8 additions & 0 deletions astropy/coordinates/tests/test_matrix_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def test_is_O3():
# and (M, 3, 3)
n1 = np.tile(m1, (2, 1, 1))
assert tuple(is_O3(n1)) == (True, True) # (show the broadcasting)
# Test atol parameter
nn1 = np.tile(0.5 * m1, (2, 1, 1))
assert tuple(is_O3(nn1)) == (False, False) # (show the broadcasting)
assert tuple(is_O3(nn1, atol=1)) == (True, True) # (show the broadcasting)

# reflection
m2 = m1.copy()
Expand All @@ -98,6 +102,10 @@ def test_is_rotation():
# and (M, 3, 3)
n1 = np.tile(m1, (2, 1, 1))
assert tuple(is_rotation(n1)) == (True, True) # (show the broadcasting)
# Test atol parameter
nn1 = np.tile(0.5 * m1, (2, 1, 1))
assert tuple(is_rotation(nn1)) == (False, False) # (show the broadcasting)
assert tuple(is_rotation(nn1, atol=10)) == (True, True) # (show the broadcasting)

# Improper rotation (unit rotation + reflection)
m2 = np.identity(3)
Expand Down
1 change: 1 addition & 0 deletions docs/changes/coordinates/14371.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added ``atol`` argument to function ``is_O3`` and ``is_rotation`` in matrix utilities.

0 comments on commit 24a4ee1

Please sign in to comment.