Skip to content

Commit

Permalink
Merge pull request #753 from StingraySoftware/fix_truncate
Browse files Browse the repository at this point in the history
Update tstart and tseg when using Lightcurve.truncate()
  • Loading branch information
mgullik committed Sep 6, 2023
2 parents 5d633c1 + c3790bc commit 5b338dc
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 21 deletions.
1 change: 1 addition & 0 deletions docs/changes/753.bugfix.rst
@@ -0,0 +1 @@
Update tstart and tseg when using Lightcurve.truncate()
7 changes: 5 additions & 2 deletions stingray/lightcurve.py
Expand Up @@ -1243,9 +1243,12 @@ def truncate(self, start=0, stop=None, method="index"):
raise ValueError("Unknown method type " + method + ".")

if method.lower() == "index":
return self._truncate_by_index(start, stop)
new_lc = self._truncate_by_index(start, stop)
else:
return self._truncate_by_time(start, stop)
new_lc = self._truncate_by_time(start, stop)
new_lc.tstart = new_lc.gti[0, 0]
new_lc.tseg = new_lc.gti[-1, 1] - new_lc.gti[0, 0]
return new_lc

def _truncate_by_index(self, start, stop):
"""Private method for truncation using index values."""
Expand Down
66 changes: 47 additions & 19 deletions stingray/tests/test_lightcurve.py
Expand Up @@ -308,8 +308,8 @@ class TestLightcurve(object):
@classmethod
def setup_class(cls):
cls.times = np.array([1, 2, 3, 4])
cls.counts = np.array([2, 2, 2, 2])
cls.counts_err = np.array([0.2, 0.2, 0.2, 0.2])
cls.counts = np.array([2, 4, 6, 8])
cls.counts_err = np.array([0.2, 0.4, 0.6, 0.8])
cls.bg_counts = np.array([1, 0, 0, 1])
cls.bg_ratio = np.array([1, 1, 0.5, 1])
cls.frac_exp = np.array([1, 1, 1, 1])
Expand Down Expand Up @@ -632,7 +632,7 @@ def test_subtraction(self):

lc = lc2 - lc1

expected_counts = np.array([1, 2, 3, 4])
expected_counts = np.array([1, 0, -1, -2])
assert np.allclose(lc.counts, expected_counts)
assert lc1.mjdref == lc.mjdref

Expand All @@ -658,7 +658,9 @@ def test_indexing_with_unexpected_type(self):
def test_indexing(self):
lc = Lightcurve(self.times, self.counts)

assert lc[0] == lc[1] == lc[2] == lc[3] == 2
assert lc[0] == 2
assert lc[1] == 4
assert lc[3] == 8

def test_slicing(self):
lc = Lightcurve(
Expand All @@ -669,20 +671,20 @@ def test_slicing(self):
err=self.counts / 10,
err_dist="gauss",
)
assert np.allclose(lc[1:3].counts, np.array([2, 2]))
assert np.allclose(lc[:2].counts, np.array([2, 2]))
assert np.allclose(lc[1:3].counts, np.array([4, 6]))
assert np.allclose(lc[:2].counts, np.array([2, 4]))
assert np.allclose(lc[:2].gti, [[0.5, 2.5]])
assert np.allclose(lc[2:].counts, np.array([2, 2]))
assert np.allclose(lc[2:].counts, np.array([6, 8]))
assert np.allclose(lc[2:].gti, [[2.5, 4.5]])
assert np.allclose(lc[:].counts, np.array([2, 2, 2, 2]))
assert np.allclose(lc[:].counts, np.array([2, 4, 6, 8]))
assert np.allclose(lc[::2].gti, [[0.5, 1.5], [2.5, 3.5]])
assert np.allclose(lc[:].gti, lc.gti)
assert lc[:].mjdref == lc.mjdref
assert lc[::2].n == 2
assert np.allclose(lc[1:3].counts_err, np.array([2, 2]) / 10)
assert np.allclose(lc[:2].counts_err, np.array([2, 2]) / 10)
assert np.allclose(lc[2:].counts_err, np.array([2, 2]) / 10)
assert np.allclose(lc[:].counts_err, np.array([2, 2, 2, 2]) / 10)
assert np.allclose(lc[1:3].counts_err, np.array([0.4, 0.6]))
assert np.allclose(lc[:2].counts_err, np.array([0.2, 0.4]))
assert np.allclose(lc[2:].counts_err, np.array([0.6, 0.8]))
assert np.allclose(lc[:].counts_err, np.array([0.2, 0.4, 0.6, 0.8]))
assert lc[:].err_dist == lc.err_dist

def test_index(self):
Expand Down Expand Up @@ -751,7 +753,8 @@ def test_join_disjoint_time_arrays(self):
lc = lc1.join(lc2)

assert len(lc.counts) == len(lc.time) == 8
assert np.allclose(lc.counts, 2)
assert np.allclose(lc.counts[4:], 2)
assert np.allclose(lc.counts[:4], self.counts)
assert lc.mjdref == lc1.mjdref

def test_join_overlapping_time_arrays(self):
Expand All @@ -765,7 +768,7 @@ def test_join_overlapping_time_arrays(self):
lc = lc1.join(lc2)

assert len(lc.counts) == len(lc.time) == 6
assert np.allclose(lc.counts, np.array([2, 2, 3, 3, 4, 4]))
assert np.allclose(lc.counts, np.array([2, 4, 5, 6, 4, 4]))

def test_join_different_err_dist_disjoint_times(self):
_times = [5, 6, 7, 8]
Expand Down Expand Up @@ -808,21 +811,30 @@ def test_truncate_by_index(self):

lc1 = lc.truncate(start=1)
assert np.allclose(lc1.time, np.array([2, 3, 4]))
assert np.allclose(lc1.counts, np.array([2, 2, 2]))
assert np.allclose(lc1.counts, np.array([4, 6, 8]))
assert np.allclose(lc1.countrate, np.array([4, 6, 8]))
assert np.allclose(lc1.bg_counts, np.array([0, 0, 1]))
assert np.allclose(lc1.bg_ratio, np.array([1, 0.5, 1]))
assert np.allclose(lc1.frac_exp, np.array([1, 1, 1]))
np.testing.assert_almost_equal(lc1.gti[0][0], 1.5)
assert lc1.mjdref == lc.mjdref
assert lc1.tstart == 1.5
assert lc1.tseg == 3
assert lc1.n == 3

lc2 = lc.truncate(stop=2)
assert np.allclose(lc2.time, np.array([1, 2]))
assert np.allclose(lc2.counts, np.array([2, 2]))
assert np.allclose(lc2.counts, np.array([2, 4]))
assert np.allclose(lc2.countrate, np.array([2, 4]))
assert np.allclose(lc2.bg_counts, np.array([1, 0]))
assert np.allclose(lc2.bg_ratio, np.array([1, 1]))
assert np.allclose(lc2.frac_exp, np.array([1, 1]))
np.testing.assert_almost_equal(lc2.gti[-1][-1], 2.5)
assert lc2.mjdref == lc.mjdref
assert lc2.n == 2

assert lc2.tstart == lc.tstart
assert lc2.tseg == 2

def test_truncate_by_time_stop_less_than_start(self):
lc = Lightcurve(self.times, self.counts)
Expand All @@ -836,19 +848,35 @@ def test_truncate_fails_with_incorrect_method(self):
lc1 = lc.truncate(start=1, method="wrong")

def test_truncate_by_time(self):
lc = Lightcurve(self.times, self.counts, gti=self.gti)
lc = Lightcurve(self.times, self.counts, err=self.counts_err, gti=self.gti)
# make sure they are initialized
lc.meancounts, lc.meanrate, lc.n

lc1 = lc.truncate(start=1, method="time")
assert np.allclose(lc1.time, np.array([1, 2, 3, 4]))
assert np.allclose(lc1.counts, np.array([2, 2, 2, 2]))
assert np.allclose(lc1.counts, np.array([2, 4, 6, 8]))
assert np.allclose(lc1.counts_err, np.array([0.2, 0.4, 0.6, 0.8]))
assert np.allclose(lc1.countrate, np.array([2, 4, 6, 8]))
np.testing.assert_almost_equal(lc1.gti[0][0], 0.5)
assert lc1.mjdref == lc.mjdref
assert lc1.tstart == 0.5
assert lc1.tseg == 4.0
assert lc1.meancounts == 5
assert lc1.meanrate == 5
assert lc1.n == 4

lc2 = lc.truncate(stop=3, method="time")
assert np.allclose(lc2.time, np.array([1, 2]))
assert np.allclose(lc2.counts, np.array([2, 2]))
assert np.allclose(lc2.counts, np.array([2, 4]))
assert np.allclose(lc2.counts_err, np.array([0.2, 0.4]))
assert np.allclose(lc2.countrate, np.array([2, 4]))
np.testing.assert_almost_equal(lc2.gti[-1][-1], 2.5)
assert lc2.mjdref == lc.mjdref
assert lc2.tstart == 0.5
assert lc2.tseg == 2
assert lc2.meancounts == 3
assert lc2.meanrate == 3
assert lc2.n == 2

def test_split_with_two_segments(self):
test_time = np.array([1, 2, 3, 6, 7, 8])
Expand Down

0 comments on commit 5b338dc

Please sign in to comment.