Skip to content

Commit

Permalink
Ensure structured SpecificTypeQuantity is possible.
Browse files Browse the repository at this point in the history
Turned out this needed a fix to how physical type IDs were returned,
as a plain numpy.void gave a FutureWarning on comparisons, and
is sensitive to names, which we do not want to be.
  • Loading branch information
mhvk committed May 24, 2021
1 parent 16945c7 commit 4705573
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 15 deletions.
46 changes: 40 additions & 6 deletions astropy/units/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,22 @@ def __iter__(self):
yield from self._units.dtype.names

# Helpers for methods below.
def _recursively_apply(self, func, as_void=False):
def _recursively_apply(self, func, cls=None):
"""Apply func recursively.
The result is stored in an instance of cls or a void.
Parameters
----------
func : callable
Function to apply to all parts of the structured unit,
recursing as needed.
cls : type, optional
If given, should be a subclass of `~numpy.void`. By default,
will return a new `~astropy.units.StructuredUnit` instance.
"""
results = np.array(tuple([func(part) for part in self.values()]),
self._units.dtype)[()]
if as_void:
return results
if cls is not None:
return results.view((cls, results.dtype))

# Short-cut; no need to interpret field names, etc.
result = super().__new__(self.__class__)
Expand Down Expand Up @@ -230,13 +237,13 @@ def cgs(self):
# Needed to pass through Unit initializer, so might as well use it.
def _get_physical_type_id(self):
return self._recursively_apply(
operator.methodcaller('_get_physical_type_id'), as_void=True)
operator.methodcaller('_get_physical_type_id'), cls=Structure)

@property
def physical_type(self):
"""Physical types of all the fields."""
return self._recursively_apply(
operator.attrgetter('physical_type'), as_void=True)
operator.attrgetter('physical_type'), cls=Structure)

def decompose(self, bases=set()):
"""The `StructuredUnit` composed of only irreducible units.
Expand Down Expand Up @@ -429,3 +436,30 @@ def __ne__(self, other):
return NotImplemented

return self.values() != other.values()


class Structure(np.void):
"""Single element structure for physical type IDs, etc.
Behaves like a `~numpy.void` and thus mostly like a tuple which can also
be indexed with field names, but overrides ``__eq__`` and ``__ne__`` to
compare only the contents, not the field names are ignored. Furthermore,
this way no `FutureWarning` about comparisons is given.
"""
# Note that it is important that it is important for physical type IDs to
# not be stored in a tuple, since then the physical types would be treated
# as alternatives in :meth:`~astropy.units.UnitBase.is_equivalent`.
# (Of course, in that case, they could also not be indexed by name.)

def __eq__(self, other):
if isinstance(other, np.void):
other = other.item()

return self.item() == other

def __ne__(self, other):
if isinstance(other, np.void):
other = other.item()

return self.item() != other
57 changes: 48 additions & 9 deletions astropy/units/tests/test_structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,22 +235,34 @@ def test_setitem_fails(self):
class TestStructuredUnitMethods(StructuredTestBaseWithUnits):
def test_physical_type_id(self):
pv_ptid = self.pv_unit._get_physical_type_id()
expected = np.array((self.pv_unit['p']._get_physical_type_id(),
self.pv_unit['v']._get_physical_type_id()),
[('p', 'O'), ('v', 'O')])
assert len(pv_ptid) == 2
assert pv_ptid.dtype.names == ('p', 'v')
p_ptid = self.pv_unit['p']._get_physical_type_id()
v_ptid = self.pv_unit['v']._get_physical_type_id()
# Expected should be (subclass of) void, with structured object dtype.
expected = np.array((p_ptid, v_ptid), [('p', 'O'), ('v', 'O')])[()]
assert pv_ptid == expected
# Names should be ignored in comparison.
assert pv_ptid == np.array((p_ptid, v_ptid), 'O,O')[()]
# Should be possible to address by field and by number.
assert pv_ptid['p'] == p_ptid
assert pv_ptid['v'] == v_ptid
assert pv_ptid[0] == p_ptid
assert pv_ptid[1] == v_ptid
# More complicated version.
pv_t_ptid = self.pv_t_unit._get_physical_type_id()
expected2 = np.array((self.pv_unit._get_physical_type_id(),
self.t_unit._get_physical_type_id()),
[('pv', 'O'), ('t', 'O')])
assert pv_t_ptid == expected2
t_ptid = self.t_unit._get_physical_type_id()
assert pv_t_ptid == np.array((pv_ptid, t_ptid), 'O,O')[()]
assert pv_t_ptid['pv'] == pv_ptid
assert pv_t_ptid['t'] == t_ptid
assert pv_t_ptid['pv'][1] == v_ptid

def test_physical_type(self):
pv_pt = self.pv_unit.physical_type
assert pv_pt == np.array(('length', 'speed'), [('p', 'O'), ('v', 'O')])
assert pv_pt == np.array(('length', 'speed'), 'O,O')[()]

pv_t_pt = self.pv_t_unit.physical_type
assert pv_t_pt == np.array((pv_pt, 'time'), [('pv', 'O'), ('t', 'O')])
assert pv_t_pt == np.array((pv_pt, 'time'), 'O,O')[()]

def test_si(self):
pv_t_si = self.pv_t_unit.si
Expand All @@ -274,6 +286,8 @@ def test_is_equivalent(self):
pv_alt = StructuredUnit('m,m/s', names=('q', 'w'))
assert pv_alt.field_names != self.pv_unit.field_names
assert self.pv_unit.is_equivalent(pv_alt)
# Regular units should work too.
assert not u.m.is_equivalent(self.pv_unit)

def test_conversion(self):
pv1 = self.pv_unit.to(('AU', 'AU/day'), self.pv)
Expand Down Expand Up @@ -566,6 +580,31 @@ def test_zeros_ones_like(self, func):
assert_array_equal(z, func(self.pv) << self.pv_unit)


class TestStructuredSpecificTypeQuantity(StructuredTestBaseWithUnits):
def setup_class(self):
super().setup_class()

class PositionVelocity(u.SpecificTypeQuantity):
_equivalent_unit = self.pv_unit

self.PositionVelocity = PositionVelocity

def test_init(self):
pv = self.PositionVelocity(self.pv, self.pv_unit)
assert isinstance(pv, self.PositionVelocity)
assert type(pv['p']) is u.Quantity
assert_array_equal(pv['p'], self.pv['p'] << self.pv_unit['p'])

pv2 = self.PositionVelocity(self.pv, 'AU,AU/day')
assert_array_equal(pv2['p'], self.pv['p'] << u.AU)

def test_error_on_non_equivalent_unit(self):
with pytest.raises(u.UnitsError):
self.PositionVelocity(self.pv, 'AU')
with pytest.raises(u.UnitsError):
self.PositionVelocity(self.pv, 'AU,yr')


class TestStructuredLogUnit:
def setup_class(self):
self.mag_time_dtype = np.dtype([('mag', 'f8'), ('t', 'f8')])
Expand Down

0 comments on commit 4705573

Please sign in to comment.