Skip to content

Commit

Permalink
Merge pull request #822 from StingraySoftware/fix_analyze_segments
Browse files Browse the repository at this point in the history
Fix operation on segment when result is NaN
  • Loading branch information
matteobachetti committed May 7, 2024
2 parents 379ed3a + 16861d1 commit ca74d3c
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 17 deletions.
1 change: 1 addition & 0 deletions docs/changes/822.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix case when analyze_segments has an invalid segment
2 changes: 1 addition & 1 deletion docs/notebooks
Submodule notebooks updated 0 files
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ filterwarnings =
ignore:.*datetime.datetime.utcfromtimestamp.*:DeprecationWarning
ignore:.*__array_wrap__ must accept context and return_scalar arguments:DeprecationWarning
ignore:.*Pyarrow:
ignore:.*Creating an ndarray from ragged nested sequences:

;addopts = --disable-warnings

Expand Down
40 changes: 29 additions & 11 deletions stingray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2582,6 +2582,8 @@ def estimate_segment_size(self, min_counts=None, min_samples=None, even_sampling
def analyze_segments(self, func, segment_size, fraction_step=1, **kwargs):
"""Analyze segments of the light curve with any function.
Intervals with less than one data point are skipped.
Parameters
----------
func : function
Expand Down Expand Up @@ -2609,8 +2611,10 @@ def analyze_segments(self, func, segment_size, fraction_step=1, **kwargs):
Lower time boundaries of all time segments.
stop_times : array
upper time boundaries of all segments.
result : array of N elements
The result of ``func`` for each segment of the light curve
result : list of N elements
The result of ``func`` for each segment of the light curve. If the function
returns multiple outputs, they are returned as a list of arrays.
If a given interval has not enough data for a calculation, ``None`` is returned.
Examples
--------
Expand Down Expand Up @@ -2649,23 +2653,37 @@ def analyze_segments(self, func, segment_size, fraction_step=1, **kwargs):
stop = np.searchsorted(self.time, stop_times)

results = []

n_outs = 1
for i, (st, sp, tst, tsp) in enumerate(zip(start, stop, start_times, stop_times)):
if sp - st <= 1:
warnings.warn(
f"Segment {i} ({tst}--{tsp}) has one data point or less. Skipping it."
f"Segment {i} ({tst}--{tsp}) has one data point or less. Skipping it "
)
res = np.nan
else:
lc_filt = self[st:sp]
lc_filt.gti = np.asanyarray([[tst, tsp]])
continue
lc_filt = self[st:sp]
lc_filt.gti = np.asanyarray([[tst, tsp]])

res = func(lc_filt, **kwargs)
res = func(lc_filt, **kwargs)
results.append(res)
if isinstance(res, Iterable) and not isinstance(res, str):
n_outs = len(res)

# If the function returns multiple outputs, we need to separate them

if n_outs > 1:
outs = [[] for _ in range(n_outs)]
for res in results:
for i in range(n_outs):
outs[i].append(res[i])
results = outs

results = np.array(results)
# Try to transform into a (possibly multi-dimensional) numpy array
try:
results = np.array(results)
except ValueError: # pragma: no cover
pass

if len(results.shape) == 2:
results = [results[:, i] for i in range(results.shape[1])]
return start_times, stop_times, results

def analyze_by_gti(self, func, fraction_step=1, **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions stingray/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,9 @@ def color(ev):
en1_ct = np.count_nonzero(mask1)
en2_ct = np.count_nonzero(mask2)

if en1_ct == 0 or en2_ct == 0:
warnings.warn("No counts in one of the energy ranges. Returning NaN")
return np.nan, np.nan
color = en2_ct / en1_ct
color_err = color * (np.sqrt(en1_ct) / en1_ct + np.sqrt(en2_ct) / en2_ct)
return color, color_err
Expand Down
11 changes: 6 additions & 5 deletions stingray/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,17 +1497,18 @@ def test_estimate_segment_size_lower_dt(self):

assert ts.estimate_segment_size(100, min_samples=40) == 8.0

def test_analyze_segments_bad_intv(self):
@pytest.mark.parametrize("n_outs", [0, 1, 2, 3])
def test_analyze_segments_bad_intv(self, n_outs):
ts = StingrayTimeseries(time=np.arange(10), dt=1, gti=[[-0.5, 0.5], [1.5, 10.5]])

def func(x):
return np.size(x.time)
if n_outs == 0:
return np.size(x.time)
return [np.size(x.time) for _ in range(n_outs)]

# I do not specify the segment_size, which means results will be calculated per-GTI
with pytest.warns(UserWarning, match="has one data point or less."):
_, _, results = ts.analyze_segments(func, segment_size=None)
# the first GTI contains only one bin, the result will be invalid
assert np.isnan(results[0])
ts.analyze_segments(func, segment_size=None)

def test_analyze_segments_by_gti(self):
ts = StingrayTimeseries(time=np.arange(11), dt=1, gti=[[-0.5, 5.5], [6.5, 10.5]])
Expand Down
6 changes: 6 additions & 0 deletions stingray/tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,12 @@ def test_colors(self):
assert np.allclose(start, np.arange(10) * 10000)
assert np.allclose(stop, np.arange(1, 11) * 10000)

def test_colors_missing_energies(self):
events = copy.deepcopy(self.events)
events.filter_energy_range([0, 3], inplace=True)
with pytest.warns(UserWarning, match="No counts in one of the energy ranges"):
events.get_color_evolution([[0, 3], [4, 6]], 10000)

def test_colors_no_segment(self):
start, stop, colors, color_errs = self.events.get_color_evolution([[0, 3], [4, 6]])
# 50000 / 50000 = 1
Expand Down

0 comments on commit ca74d3c

Please sign in to comment.