Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: implement a base comparison operator (__eq__) for NDData #15903

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 23 additions & 0 deletions astropy/nddata/nddata.py
Expand Up @@ -444,3 +444,26 @@
# setter).
value.parent_nddata = self
self._uncertainty = value

def __eq__(self, other) -> bool:
def is_sequence(a) -> bool:
try:
len(a)
except TypeError:
return False

Check warning on line 453 in astropy/nddata/nddata.py

View check run for this annotation

Codecov / codecov/patch

astropy/nddata/nddata.py#L449-L453

Added lines #L449 - L453 were not covered by tests
else:
return True

Check warning on line 455 in astropy/nddata/nddata.py

View check run for this annotation

Codecov / codecov/patch

astropy/nddata/nddata.py#L455

Added line #L455 was not covered by tests

def seqeq(a, b) -> bool:

Check warning on line 457 in astropy/nddata/nddata.py

View check run for this annotation

Codecov / codecov/patch

astropy/nddata/nddata.py#L457

Added line #L457 was not covered by tests
# return True iff a==b
if is_sequence(a) and is_sequence(b):
return bool(np.all(np.asanyarray(a) == np.asanyarray(b)))

Check warning on line 460 in astropy/nddata/nddata.py

View check run for this annotation

Codecov / codecov/patch

astropy/nddata/nddata.py#L459-L460

Added lines #L459 - L460 were not covered by tests
else:
return a == b

Check warning on line 462 in astropy/nddata/nddata.py

View check run for this annotation

Codecov / codecov/patch

astropy/nddata/nddata.py#L462

Added line #L462 was not covered by tests

# checks are ordered from cheapest to most expensive
return other is self or (

Check warning on line 465 in astropy/nddata/nddata.py

View check run for this annotation

Codecov / codecov/patch

astropy/nddata/nddata.py#L465

Added line #L465 was not covered by tests
self.unit == other.unit
and seqeq(self.mask, other.mask)
and seqeq(self.data, other.data)
)
25 changes: 25 additions & 0 deletions astropy/nddata/tests/test_nddata.py
Expand Up @@ -707,3 +707,28 @@ def test_collapse(mask, unit, propagate_uncertainties, operation_ignores_mask):
# as the data array, so we can just check for equality:
if method in ext_methods and propagate_uncertainties:
assert np.ma.all(np.ma.equal(astropy_method, nddata_method))


@pytest.mark.parametrize(
"nd1, nd2, expected",
[
(NDData([1]), None, True),
(NDData([1], mask=True), None, True),
(NDData([1], mask=False), None, True),
(NDData([1], unit=u.J), None, True),
(NDData([1]), NDData([1]), True),
(NDData([1]), NDData([1], mask=True), False),
(NDData([1]), NDData([1], unit=u.J), False),
(NDData([1], mask=True), NDData([1], mask=True, unit=u.J), False),
(NDData([1], mask=False, unit=u.J), NDData([1], mask=True, unit=u.J), False),
(NDData([1], mask=True, unit=u.K), NDData([1], mask=True, unit=u.J), False),
(NDData([1], mask=True, unit=u.J), NDData([1], mask=True, unit=u.J), True),
],
)
def test_nddata_eq(nd1, nd2, expected):
if nd2 is None:
# special case to check that we don't break default comparison
# __eq__(self, other) = lambda self, other: self is other
nd2 = nd1
assert (nd1 == nd2) is expected
assert (nd1 != nd2) ^ expected
1 change: 1 addition & 0 deletions docs/changes/15903.other.rst
@@ -0,0 +1 @@
``NDData`` objects now supports comparison operators ``==`` and ``!=``.