Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add atol argument to function is_O3 #14371

Merged
merged 1 commit into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 28 additions & 6 deletions astropy/coordinates/matrix_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,19 @@ def angle_axis(matrix):
return Angle(angle, u.radian), -axis / r


def is_O3(matrix):
def is_O3(matrix, atol=None):
"""Check whether a matrix is in the length-preserving group O(3).

Parameters
----------
matrix : (..., N, N) array-like
Must have attribute ``.shape`` and method ``.swapaxes()`` and not error
when using `~numpy.isclose`.
atol : float, optional
The allowed absolute difference.
If `None` it defaults to 1e-15 or 5 * epsilon of the matrix's dtype, if floating.

.. versionadded:: 5.3

Returns
-------
Expand All @@ -159,14 +164,20 @@ def is_O3(matrix):
"""
# matrix is in O(3) (rotations, proper and improper).
I = np.identity(matrix.shape[-1])
nstarman marked this conversation as resolved.
Show resolved Hide resolved
if atol is None:
if np.issubdtype(matrix.dtype, np.floating):
atol = np.finfo(matrix.dtype).eps * 5
else:
atol = 1e-15

is_o3 = np.all(
np.isclose(matrix @ matrix.swapaxes(-2, -1), I, atol=1e-15), axis=(-2, -1)
np.isclose(matrix @ matrix.swapaxes(-2, -1), I, atol=atol), axis=(-2, -1)
)

return is_o3


def is_rotation(matrix, allow_improper=False):
def is_rotation(matrix, allow_improper=False, atol=None):
"""Check whether a matrix is a rotation, proper or improper.

Parameters
Expand All @@ -178,6 +189,11 @@ def is_rotation(matrix, allow_improper=False):
Whether to restrict check to the SO(3), the group of proper rotations,
or also allow improper rotations (with determinant -1).
The default (False) is only SO(3).
atol : float, optional
The allowed absolute difference.
If `None` it defaults to 1e-15 or 5 * epsilon of the matrix's dtype, if floating.

.. versionadded:: 5.3

Returns
-------
Expand All @@ -198,13 +214,19 @@ def is_rotation(matrix, allow_improper=False):
For more information, see https://en.wikipedia.org/wiki/Orthogonal_group

"""
if atol is None:
if np.issubdtype(matrix.dtype, np.floating):
atol = np.finfo(matrix.dtype).eps * 5
else:
atol = 1e-15

# matrix is in O(3).
is_o3 = is_O3(matrix)
is_o3 = is_O3(matrix, atol=atol)

# determinant checks for rotation (proper and improper)
if allow_improper: # determinant can be +/- 1
is_det1 = np.isclose(np.abs(np.linalg.det(matrix)), 1.0)
is_det1 = np.isclose(np.abs(np.linalg.det(matrix)), 1.0, atol=atol)
else: # restrict to SO(3)
is_det1 = np.isclose(np.linalg.det(matrix), 1.0)
is_det1 = np.isclose(np.linalg.det(matrix), 1.0, atol=atol)

return is_o3 & is_det1
8 changes: 8 additions & 0 deletions astropy/coordinates/tests/test_matrix_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def test_is_O3():
# and (M, 3, 3)
n1 = np.tile(m1, (2, 1, 1))
assert tuple(is_O3(n1)) == (True, True) # (show the broadcasting)
# Test atol parameter
nn1 = np.tile(0.5 * m1, (2, 1, 1))
assert tuple(is_O3(nn1)) == (False, False) # (show the broadcasting)
assert tuple(is_O3(nn1, atol=1)) == (True, True) # (show the broadcasting)

# reflection
m2 = m1.copy()
Expand All @@ -98,6 +102,10 @@ def test_is_rotation():
# and (M, 3, 3)
n1 = np.tile(m1, (2, 1, 1))
assert tuple(is_rotation(n1)) == (True, True) # (show the broadcasting)
# Test atol parameter
nn1 = np.tile(0.5 * m1, (2, 1, 1))
assert tuple(is_rotation(nn1)) == (False, False) # (show the broadcasting)
assert tuple(is_rotation(nn1, atol=10)) == (True, True) # (show the broadcasting)

# Improper rotation (unit rotation + reflection)
m2 = np.identity(3)
Expand Down
1 change: 1 addition & 0 deletions docs/changes/coordinates/14371.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added ``atol`` argument to function ``is_O3`` and ``is_rotation`` in matrix utilities.