Skip to content

Commit

Permalink
Test all formats and fix bug for ymdhms
Browse files Browse the repository at this point in the history
  • Loading branch information
mhvk committed Aug 26, 2023
1 parent f2fd522 commit 65ebd26
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
14 changes: 10 additions & 4 deletions astropy/time/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3267,20 +3267,26 @@ def get_as_filled_ndarray(mask, val, masked_cls):
filled = val.unmasked
fill_value = np.zeros_like(filled, shape=())

if np.any(val.mask):
# For structured dtype, the mask is structured too. We consider an
# array element masked if any field of the structure is masked.
if val.dtype.names:
val_mask = val.mask != np.zeros_like(val.mask, shape=())
else:
val_mask = val.mask
if np.any(val_mask):
# We're going to fill masked values, so make a copy.
filled = filled.copy()

# Final mask is the logical-or of inputs
mask = mask | val.mask
mask = mask | val_mask

# First unmasked element. If all elements are masked then
# use fill_value from above. For MaskedArray, this uses val.fill_value,
# so all will be fine as long as the user has set this appropriately.
if filled.size > 1:
first_unmasked = val.mask.argmin()
first_unmasked = val_mask.argmin()
# Result indexes first False, or first item if all True.
if first_unmasked > 0 or not val.mask.flat[0]:
if first_unmasked > 0 or not val_mask.flat[0]:
fill_value = filled.flat[first_unmasked]

filled[mask] = fill_value
Expand Down
32 changes: 32 additions & 0 deletions astropy/time/tests/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,35 @@ def test_serialize_ecsv_masked(serialize_method, tmp_path):
# Serializing formatted_value loses some precision.
atol = 0.1 * u.us if serialize_method == "formatted_value" else 1 * u.ps
assert np.all(abs(t2["col0"] - t["col0"]) <= atol)


@pytest.mark.parametrize("format_", Time.FORMATS)
@pytest.mark.parametrize("masked_cls", [np.ma.MaskedArray, Masked])
def test_all_formats(format_, masked_cls):
mjd = np.array([55000.25, 55000.375, 55001.125])
mask = np.array([True, False, False])
mjdm = masked_cls(mjd, mask=mask)
t = Time(mjd, format="mjd")
tm = Time(mjdm, format="mjd")
assert tm.masked and np.all(tm.mask == mask)

# Get values in the given format, check that these have the appropriate class
# and are correct (ignoring masked ones, which get adjusted on Time
# initialization, in core._check_for_masked_and_fill).
t_format = getattr(t, format_)
tm_format = getattr(tm, format_)
assert isinstance(tm_format, masked_cls)
if format_ == "ymdhms" and masked_cls is np.ma.MaskedArray:
# Work around https://github.com/numpy/numpy/issues/24554
# TODO: just compare with t_format once resolved.
expected = masked_cls(t_format, mask=mask)
else:
expected = t_format
assert np.all(tm_format == expected)

# Verify that we can also initialize with the format and that this gives
# the right result and mask too.
t2 = Time(t_format, format=format_)
tm2 = Time(tm_format, format=format_)
assert tm2.masked and np.all(tm2.mask == mask)
assert np.all(tm2 == t2)

0 comments on commit 65ebd26

Please sign in to comment.