Skip to content

Commit

Permalink
test timestamps in cut plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
JoranAngevaare committed Jul 21, 2020
1 parent ab29d7f commit 3673c28
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions tests/test_cut_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
# Initialize. We test both dt time-fields and time time-field
_dtype_name = 'var'
_cut_dtype = ('variable 0', _dtype_name)
full_dt_dtype = [(_cut_dtype, np.float64)] + strax.time_dt_fields
full_time_dtype = [(_cut_dtype, np.float64)] + strax.time_fields
full_dt_dtype = [(_cut_dtype, np.float64)] + strax.time_dt_fields
full_time_dtype = [(_cut_dtype, np.float64)] + strax.time_fields


def get_some_array():
Expand All @@ -16,7 +16,8 @@ def get_some_array():

# Stolen from testutils.bounds_to_intervals
def bounds_to_intervals(bs, dt=1):
x = np.zeros(len(bs), dtype=full_dt_dtype if take_dt else full_time_dtype)
x = np.zeros(len(bs),
dtype=full_dt_dtype if take_dt else full_time_dtype)
x['time'] = [x[0] for x in bs]
# Remember: exclusive right bound...
if take_dt:
Expand All @@ -26,10 +27,11 @@ def bounds_to_intervals(bs, dt=1):
x['endtime'] = x['time'] + ([x[1] - x[0] for x in bs]) * dt
return x

# Randomly imput either of full_dt_dtype or full_time_dtype
# Randomly input either of full_dt_dtype or full_time_dtype
sorted_intervals = testutils.sorted_bounds().map(bounds_to_intervals)
return sorted_intervals


@given(get_some_array().filter(lambda x: len(x) >= 0),
strategies.integers(min_value=-10, max_value=10))
@settings(deadline=None)
Expand All @@ -42,7 +44,7 @@ def bounds_to_intervals(bs, dt=1):
(11, 5, 7),
(7, 7, 9)
],
dtype=[(_cut_dtype, np.float64)] + strax.time_fields),
dtype=[(_cut_dtype, np.float64)] + strax.time_fields),
cut_threshold=5)
@example(
input_peaks=np.array(
Expand All @@ -51,7 +53,7 @@ def bounds_to_intervals(bs, dt=1):
(5, 2, 2, 1),
(11, 4, 2, 4)
],
dtype=[(_cut_dtype, np.int16)] + strax.time_dt_fields),
dtype=[(_cut_dtype, np.int16)] + strax.time_dt_fields),
cut_threshold=-1)
def test_cut_plugin(input_peaks, cut_threshold):
"""
Expand All @@ -71,8 +73,10 @@ def compute(self, chunk_i):
data = chunks[chunk_i]
return self.chunk(
data=data,
start=int(data[0]['time']) if len(data) else np.arange(len(chunks))[chunk_i],
end=int(strax.endtime(data[-1])) if len(data) else np.arange(1, len(chunks)+1)[chunk_i])
start=(int(data[0]['time']) if len(data)
else np.arange(len(chunks))[chunk_i]),
end=(int(strax.endtime(data[-1])) if len(data)
else np.arange(1, len(chunks) + 1)[chunk_i]))

# Hack to make peak output stop after a few chunks
def is_ready(self, chunk_i):
Expand All @@ -97,4 +101,13 @@ def cut_by(self, to_be_cut):
targets=strax.camel_to_snake(CutSomething.__name__))
correct_answer = np.sum(input_peaks[_dtype_name] > cut_threshold)
assert len(result) == len(input_peaks), "WTF??"
assert correct_answer == np.sum(result['cut_something']), "Cut plugin does not give boolean arrays correctly"
assert correct_answer == np.sum(result['cut_something']), (
"Cut plugin does not give boolean arrays correctly")

if len(input_peaks):
assert strax.endtime(input_peaks).max() == \
strax.endtime(result).max(), "last end time got scrambled"
assert np.all(input_peaks['time'] ==
result['time']), "(start) times got scrambled"
assert np.all(strax.endtime(input_peaks) ==
strax.endtime(result)), "Some end times got scrambled"

0 comments on commit 3673c28

Please sign in to comment.