Skip to content

Commit

Permalink
Add check for time series ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Mar 13, 2024
1 parent be5e2a3 commit ff158e0
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 35 deletions.
13 changes: 12 additions & 1 deletion stingray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,10 @@ class StingrayTimeseries(StingrayObject):
ephem : str
The JPL ephemeris used to barycenter the data, if any (e.g. DE430)
skip_checks : bool
Skip checks on the time array. Useful when the user is reasonably sure that the
input data are valid.
**other_kw :
Used internally. Any other keyword arguments will be set as attributes of the object.
Expand Down Expand Up @@ -1176,6 +1180,7 @@ def __init__(
ephem: str = None,
timeref: str = None,
timesys: str = None,
skip_checks: bool = False,
**other_kw,
):
StingrayObject.__init__(self)
Expand All @@ -1198,6 +1203,12 @@ def __init__(
if self.time.shape[0] != new_arr.shape[0]:
raise ValueError(f"Lengths of time and {kw} must be equal.")
setattr(self, kw, new_arr)
from .utils import is_sorted

if not skip_checks:
if self.time is not None and not is_sorted(self.time):
warnings.warn("The time array is not sorted. Sorting it now.")
self.sort(inplace=True)

@property
def time(self):
Expand Down Expand Up @@ -2151,7 +2162,7 @@ def sort(self, reverse=False, inplace=False):
--------
>>> time = [2, 1, 3]
>>> count = [200, 100, 300]
>>> ts = StingrayTimeseries(time, array_attrs={"counts": count}, dt=1)
>>> ts = StingrayTimeseries(time, array_attrs={"counts": count}, dt=1, skip_checks=True)
>>> ts_new = ts.sort()
>>> ts_new.time
array([1, 2, 3])
Expand Down
8 changes: 7 additions & 1 deletion stingray/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ class EventList(StingrayTimeseries):
rmf_file : str, default None
The file name of the RMF file to use for calibration.
skip_checks : bool, default False
Skip checks for the validity of the event list. Use with caution.
**other_kw :
Used internally. Any other keyword arguments will be ignored
Expand Down Expand Up @@ -211,6 +214,7 @@ def __init__(
timeref=None,
timesys=None,
rmf_file=None,
skip_checks=False,
**other_kw,
):
if ncounts is not None:
Expand Down Expand Up @@ -243,6 +247,7 @@ def __init__(
timeref=timeref,
timesys=timesys,
rmf_file=rmf_file,
skip_checks=skip_checks,
**other_kw,
)

Expand Down Expand Up @@ -502,7 +507,8 @@ def sort(self, inplace=False):
Examples
--------
>>> events = EventList(time=[0, 2, 1], energy=[0.3, 2, 0.5], pi=[3, 20, 5])
>>> events = EventList(time=[0, 2, 1], energy=[0.3, 2, 0.5], pi=[3, 20, 5],
... skip_checks=True)
>>> e1 = events.sort()
>>> assert np.allclose(e1.time, [0, 1, 2])
>>> assert np.allclose(e1.energy, [0.3, 0.5, 2])
Expand Down
1 change: 1 addition & 0 deletions stingray/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
:class::class:`Lightcurve` is used to create light curves out of photon counting data
or to save existing light curves in a class that's easy to use.
"""

import os
import logging
import warnings
Expand Down
44 changes: 26 additions & 18 deletions stingray/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,9 +850,10 @@ def test_sort(self):
bleh = [4, 1, 2, 0.5]
mjdref = 57000

lc = StingrayTimeseries(
times, array_attrs={"blah": blah, "_bleh": bleh}, dt=1, mjdref=mjdref
)
with pytest.warns(UserWarning, match="The time array is not sorted."):
lc = StingrayTimeseries(
times, array_attrs={"blah": blah, "_bleh": bleh}, dt=1, mjdref=mjdref
)

lc_new = lc.sort()

Expand Down Expand Up @@ -1223,12 +1224,14 @@ def test_non_overlapping_join_infer(self):

def test_overlapping_join_infer(self):
"""Join two non-overlapping event lists."""
ts = StingrayTimeseries(
time=[1, 1.1, 10, 6, 5], energy=[10, 6, 3, 11, 2], gti=[[1, 3], [5, 6]]
)
ts_other = StingrayTimeseries(
time=[5.1, 7, 6.1, 6.11, 10.1], energy=[2, 3, 8, 1, 2], gti=[[5, 7], [8, 10]]
)
with pytest.warns(UserWarning, match="The time array is not sorted."):
ts = StingrayTimeseries(
time=[1, 1.1, 10, 6, 5], energy=[10, 6, 3, 11, 2], gti=[[1, 3], [5, 6]]
)
with pytest.warns(UserWarning, match="The time array is not sorted."):
ts_other = StingrayTimeseries(
time=[5.1, 7, 6.1, 6.11, 10.1], energy=[2, 3, 8, 1, 2], gti=[[5, 7], [8, 10]]
)
ts_new = ts.join(ts_other, strategy="infer")

assert (ts_new.time == np.array([1, 1.1, 5, 5.1, 6, 6.1, 6.11, 7, 10, 10.1])).all()
Expand All @@ -1237,15 +1240,20 @@ def test_overlapping_join_infer(self):

def test_overlapping_join_change_mjdref(self):
"""Join two non-overlapping event lists."""
ts = StingrayTimeseries(
time=[1, 1.1, 10, 6, 5], energy=[10, 6, 3, 11, 2], gti=[[1, 3], [5, 6]], mjdref=57001
)
ts_other = StingrayTimeseries(
time=np.asarray([5.1, 7, 6.1, 6.11, 10.1]) + 86400,
energy=[2, 3, 8, 1, 2],
gti=np.asarray([[5, 7], [8, 10]]) + 86400,
mjdref=57000,
)
with pytest.warns(UserWarning, match="The time array is not sorted."):
ts = StingrayTimeseries(
time=[1, 1.1, 10, 6, 5],
energy=[10, 6, 3, 11, 2],
gti=[[1, 3], [5, 6]],
mjdref=57001,
)
with pytest.warns(UserWarning, match="The time array is not sorted."):
ts_other = StingrayTimeseries(
time=np.asarray([5.1, 7, 6.1, 6.11, 10.1]) + 86400,
energy=[2, 3, 8, 1, 2],
gti=np.asarray([[5, 7], [8, 10]]) + 86400,
mjdref=57000,
)
with pytest.warns(UserWarning, match="Attribute mjdref is different"):
ts_new = ts.join(ts_other, strategy="intersection")

Expand Down
33 changes: 20 additions & 13 deletions stingray/tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,12 @@ def test_non_overlapping_join_infer(self):

def test_overlapping_join_infer(self):
"""Join two non-overlapping event lists."""
ev = EventList(time=[1, 1.1, 10, 6, 5], energy=[10, 6, 3, 11, 2], gti=[[1, 3], [5, 6]])
ev_other = EventList(
time=[5.1, 7, 6.1, 6.11, 10.1], energy=[2, 3, 8, 1, 2], gti=[[5, 7], [8, 10]]
)
with pytest.warns(UserWarning, match="The time array is not sorted."):
ev = EventList(time=[1, 1.1, 10, 6, 5], energy=[10, 6, 3, 11, 2], gti=[[1, 3], [5, 6]])
with pytest.warns(UserWarning, match="The time array is not sorted."):
ev_other = EventList(
time=[5.1, 7, 6.1, 6.11, 10.1], energy=[2, 3, 8, 1, 2], gti=[[5, 7], [8, 10]]
)
ev_new = ev.join(ev_other, strategy="infer")

assert (ev_new.time == np.array([1, 1.1, 5, 5.1, 6, 6.1, 6.11, 7, 10, 10.1])).all()
Expand All @@ -500,15 +502,20 @@ def test_overlapping_join_infer(self):

def test_overlapping_join_change_mjdref(self):
"""Join two non-overlapping event lists."""
ev = EventList(
time=[1, 1.1, 10, 6, 5], energy=[10, 6, 3, 11, 2], gti=[[1, 3], [5, 6]], mjdref=57001
)
ev_other = EventList(
time=np.asarray([5.1, 7, 6.1, 6.11, 10.1]) + 86400,
energy=[2, 3, 8, 1, 2],
gti=np.asarray([[5, 7], [8, 10]]) + 86400,
mjdref=57000,
)
with pytest.warns(UserWarning, match="The time array is not sorted."):
ev = EventList(
time=[1, 1.1, 10, 6, 5],
energy=[10, 6, 3, 11, 2],
gti=[[1, 3], [5, 6]],
mjdref=57001,
)
with pytest.warns(UserWarning, match="The time array is not sorted."):
ev_other = EventList(
time=np.asarray([5.1, 7, 6.1, 6.11, 10.1]) + 86400,
energy=[2, 3, 8, 1, 2],
gti=np.asarray([[5, 7], [8, 10]]) + 86400,
mjdref=57000,
)
with pytest.warns(UserWarning, match="Attribute mjdref is different"):
ev_new = ev.join(ev_other, strategy="intersection")

Expand Down
4 changes: 2 additions & 2 deletions stingray/tests/test_lombscargle.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ def test_raise_on_invalid_function(self, func_name):
func()

def test_no_dt(self):
el1 = EventList(self.lc1.counts, self.lc1.time, dt=None)
el2 = EventList(self.lc2.counts, self.lc2.time, dt=None)
el1 = EventList(self.lc1.time, self.lc1.counts, dt=None)
el2 = EventList(self.lc2.time, self.lc2.counts, dt=None)
with pytest.raises(ValueError):
lscs = LombScargleCrossspectrum(el1, el2)

Expand Down

0 comments on commit ff158e0

Please sign in to comment.