Skip to content

Commit

Permalink
Merge pull request #10432 from mhvk/coord-quantity-attribute-numpy-de…
Browse files Browse the repository at this point in the history
…v-deprecation

Check values equal to zero in QuantityAttribute a bit more carefully.
  • Loading branch information
adrn authored and astrofrog committed Jun 5, 2020
1 parent a7cedc1 commit 63116d3
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 14 deletions.
32 changes: 18 additions & 14 deletions astropy/coordinates/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,20 +321,24 @@ def convert_input(self, value):
if value is None:
return None, False

if np.all(value == 0) and self.unit is not None:
return u.Quantity(np.zeros(self.shape), self.unit), True
else:
if not hasattr(value, 'unit') and self.unit != u.dimensionless_unscaled:
raise TypeError('Tried to set a QuantityAttribute with '
'something that does not have a unit.')
oldvalue = value
value = u.Quantity(oldvalue, self.unit, copy=False)
if self.shape is not None and value.shape != self.shape:
raise ValueError('The provided value has shape "{}", but '
'should have shape "{}"'.format(value.shape,
self.shape))
converted = oldvalue is not value
return value, converted
if (not hasattr(value, 'unit') and self.unit != u.dimensionless_unscaled
and np.any(value != 0)):
raise TypeError('Tried to set a QuantityAttribute with '
'something that does not have a unit.')

oldvalue = value
value = u.Quantity(oldvalue, self.unit, copy=False)
if self.shape is not None and value.shape != self.shape:
if value.shape == () and oldvalue == 0:
# Allow a single 0 to fill whatever shape is needed.
value = np.broadcast_to(value, self.shape, subok=True)
else:
raise ValueError(
f'The provided value has shape "{value.shape}", but '
f'should have shape "{self.shape}"')

converted = oldvalue is not value
return value, converted


class EarthLocationAttribute(Attribute):
Expand Down
32 changes: 32 additions & 0 deletions astropy/coordinates/tests/test_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,38 @@ class MyCoord2(BaseCoordinateFrame):
frame = MyCoord2()
assert u.isclose(frame.someval, 15*u.deg)

# Since here no shape was given, we can set to any shape we like.
frame = MyCoord2(someval=np.ones(3)*u.deg)
assert frame.someval.shape == (3,)
assert np.all(frame.someval == 1*u.deg)

# We should also be able to insist on a given shape.
class MyCoord3(BaseCoordinateFrame):
someval = QuantityAttribute(unit=u.arcsec, shape=(3,))

frame = MyCoord3(someval=np.ones(3)*u.deg)
assert frame.someval.shape == (3,)
assert frame.someval.unit == u.arcsec
assert u.allclose(frame.someval.value, 3600.)

# The wrong shape raises.
with pytest.raises(ValueError, match='shape'):
MyCoord3(someval=1.*u.deg)

# As does the wrong unit.
with pytest.raises(u.UnitsError):
MyCoord3(someval=np.ones(3)*u.m)

# We are allowed a short-cut for zero.
frame0 = MyCoord3(someval=0)
assert frame0.someval.shape == (3,)
assert frame0.someval.unit == u.arcsec
assert np.all(frame0.someval.value == 0.)

# But not if it has the wrong shape.
with pytest.raises(ValueError, match='shape'):
MyCoord3(someval=np.zeros(2))

# This should fail, if we don't pass in a default or a unit
with pytest.raises(ValueError):
class MyCoord(BaseCoordinateFrame):
Expand Down

0 comments on commit 63116d3

Please sign in to comment.