Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure Distribution can be used in Latitude and Longitude #14421

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 8 additions & 2 deletions astropy/coordinates/angles.py
Expand Up @@ -553,11 +553,17 @@ def _validate_angles(self, angles=None):
if angles.unit is u.deg:
limit = 90
elif angles.unit is u.rad:
limit = self.dtype.type(0.5 * np.pi)
limit = 0.5 * np.pi
else:
limit = u.degree.to(angles.unit, 90.0)

invalid_angles = np.any(angles.value < -limit) or np.any(angles.value > limit)
# Ensure ndim>=1 so that comparison is done using the angle dtype.
# Otherwise, e.g., np.array(np.pi/2, 'f4') > np.pi/2 will yield True.
# (This feels like a bug -- see https://github.com/numpy/numpy/issues/23247)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like since this was written there might be a fix in a future numpy for this - might want to add a todo note aobut that

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# Note that we should avoid using `angles.dtype` directly since for
# structured arrays like Distribution this will be `void`.
angles_view = angles.view(np.ndarray)[np.newaxis]
invalid_angles = np.any(angles_view < -limit) or np.any(angles_view > limit)
if invalid_angles:
raise ValueError(
"Latitude angle(s) must be within -90 deg <= angle <= 90 deg, "
Expand Down
27 changes: 27 additions & 0 deletions astropy/uncertainty/core.py
Expand Up @@ -311,6 +311,9 @@ def view(self, dtype=None, type=None):

# Override __getitem__ so that 'samples' is returned as the sample class.
def __getitem__(self, item):
if isinstance(item, Distribution):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thing it might be better to use duck-typing here - will make a follow-on PR with a proposal

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after looking more closely I realize this is a bigger topic that's not worth worrying about right now, because it would be more consistent to do this with all the isinstance calls in this sub-package, which might have unintended consequences...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. While usually a big fan of duck-typing, I think in this instance, isinstance may actually be OK; e.g., I don't think we can count on other implementations to use .distribution as an attribute... But something to think about nevertheless...

# Required for in-place operations like dist[dist < 0] += 360.
return self.distribution[item.distribution]
result = super().__getitem__(item)
if item == "samples":
# Here, we need to avoid our own redefinition of view.
Expand All @@ -320,6 +323,30 @@ def __getitem__(self, item):
else:
return result

def __setitem__(self, item, value):
if isinstance(item, Distribution):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(same here as above)

# Support operations like dist[dist < 0] = 0.
self.distribution[item.distribution] = value
else:
super().__setitem__(item, value)

# Override __eq__ and __ne__ to pass on directly to the ufunc since
# otherwise comparisons with non-distributions do not work (but
# deferring if other defines __array_ufunc__ = None -- see
# numpy/core/src/common/binop_override.h for the logic; we assume we
# will never deal with __array_priority__ any more). Note: there is no
# problem for other comparisons, since for those, structured arrays are
# not treated differently in numpy/core/src/multiarray/arrayobject.c.
def __eq__(self, other):
if getattr(other, "__array_ufunc__", False) is None:
return NotImplemented
return np.equal(self, other)

def __ne__(self, other):
if getattr(other, "__array_ufunc__", False) is None:
return NotImplemented
return np.not_equal(self, other)


class _DistributionRepr:
def __repr__(self):
Expand Down
72 changes: 72 additions & 0 deletions astropy/uncertainty/tests/test_containers.py
@@ -0,0 +1,72 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Test that Distribution works with classes other than ndarray and Quantity."""

import numpy as np
import pytest
from numpy.testing import assert_array_equal

import astropy.units as u
from astropy.coordinates import Angle, Latitude, Longitude
from astropy.uncertainty import Distribution


class TestAngles:
@classmethod
def setup_class(cls):
cls.a = np.arange(27.0).reshape(3, 9)
cls.d = Distribution(cls.a)
cls.q = cls.a << u.deg
cls.dq = Distribution(cls.q)

@pytest.mark.parametrize("angle_cls", [Angle, Longitude, Latitude])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe should add EarthLocation to this list? Although that might be a whole can of worms.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EarthLocation will be interesting, since it is already a structured type. That will really need some thought, since presumably each sample then has x,y,z inside, but we still have to be sure el_dist['x'] works... In other words, I'm pretty sure this will not yet work!

def test_as_input_for_angle(self, angle_cls):
da = angle_cls(self.dq)
assert isinstance(da, angle_cls)
assert isinstance(da, Distribution)
assert_array_equal(da.distribution, angle_cls(self.q))

@pytest.mark.parametrize("angle_cls", [Angle, Longitude, Latitude])
def test_using_angle_as_input(self, angle_cls):
a = angle_cls(self.q)
da = Distribution(a)
assert isinstance(da, angle_cls)
assert isinstance(da, Distribution)

# Parametrize the unit to check the various branches in Latitude._validate_angles
@pytest.mark.parametrize("dtype", ["f8", "f4"])
@pytest.mark.parametrize(
"value", [90 * u.deg, np.pi / 2 * u.radian, 90 * 60 * u.arcmin]
)
def test_at_limit_for_latitude(self, value, dtype):
q = u.Quantity(value, dtype=dtype).reshape(1)
qd = Distribution(q)
ld = Latitude(qd)
assert_array_equal(ld.distribution, Latitude(q))

# Parametrize the unit in case Longitude._wrap_at becomes unit-dependent.
@pytest.mark.parametrize("dtype", ["f8", "f4"])
@pytest.mark.parametrize(
"value", [360 * u.deg, 2 * np.pi * u.radian, 360 * 60 * u.arcmin]
)
def test_at_wrap_angle_for_longitude(self, value, dtype):
q = u.Quantity(value, dtype=dtype).reshape(1)
qd = Distribution(q)
ld = Longitude(qd)
assert_array_equal(ld.distribution, Longitude(q))
assert np.all(ld.distribution == 0)

@pytest.mark.parametrize("angle_cls", [Longitude, Latitude])
def test_operation_gives_correct_subclass(self, angle_cls):
# Lon and Lat always fall back to Angle
da = angle_cls(self.dq)
da2 = da + da
assert isinstance(da, Angle)
assert isinstance(da, Distribution)

@pytest.mark.parametrize("angle_cls", [Longitude, Latitude])
def test_pdfstd_gives_correct_subclass(self, angle_cls):
# Lon and Lat always fall back to Angle
da = angle_cls(self.dq)
std = da.pdf_std()
assert isinstance(std, Angle)
assert_array_equal(std, Angle(self.q.std(-1)))
42 changes: 42 additions & 0 deletions astropy/uncertainty/tests/test_distribution.py
@@ -1,4 +1,5 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import operator

import numpy as np
import pytest
Expand Down Expand Up @@ -471,3 +472,44 @@ def test_scalar_quantity_distribution():
assert isinstance(sin_angles, Distribution)
assert isinstance(sin_angles, u.Quantity)
assert_array_equal(sin_angles, Distribution(np.sin([90.0, 30.0, 0.0] * u.deg)))


@pytest.mark.parametrize("op", [operator.eq, operator.ne, operator.gt])
class TestComparison:
@classmethod
def setup_class(cls):
cls.d = Distribution([90.0, 30.0, 0.0])

class Override:
__array_ufunc__ = None

def __eq__(self, other):
return "eq"

def __ne__(self, other):
return "ne"

def __lt__(self, other):
return "gt" # Since it is called for the reverse of gt

cls.override = Override()

def test_distribution_can_be_compared_to_non_distribution(self, op):
result = op(self.d, 0.0)
assert_array_equal(result, Distribution(op(self.d.distribution, 0.0)))

def test_distribution_comparison_defers_correctly(self, op):
result = op(self.d, self.override)
assert result == op.__name__


class TestSetItemWithSelection:
def test_setitem(self):
d = Distribution([90.0, 30.0, 0.0])
d[d > 50] = 0.0
assert_array_equal(d, Distribution([0.0, 30.0, 0.0]))

def test_inplace_operation(self):
d = Distribution([90.0, 30.0, 0.0])
d[d > 50] *= -1.0
assert_array_equal(d, Distribution([-90.0, 30.0, 0.0]))
3 changes: 3 additions & 0 deletions docs/changes/uncertainty/14421.bugfix.rst
@@ -0,0 +1,3 @@
Ensure that ``Distribution`` can be compared with ``==`` and ``!=``
with regular arrays or scalars, and that inplace operations like
``dist[dist<0] *= -1`` work.