Add more tests to strax #359
Merged

Commits (9):
- 8d687b9 fix loop plugin test (JoranAngevaare)
- 4f070ae Merge branch 'master' into loop_plugin_multioutput (JoranAngevaare)
- ef07404 tweak loop_plugin tests (JoranAngevaare)
- 5ca372e Merge branch 'loop_plugin_multioutput' of https://github.com/AxFounda… (JoranAngevaare)
- 2c64f76 also test peak splitting (JoranAngevaare)
- 985ccc5 add tests for peak_properties (JoranAngevaare)
- c106d6e remove silly test (JoranAngevaare)
- 76276c0 address review comments (JoranAngevaare)
- 164a252 Merge branch 'master' into loop_plugin_multioutput (JoranAngevaare)
First changed file (one line removed from the imports):

```diff
@@ -3,7 +3,6 @@
 import tempfile
 import os
 import os.path as osp
 
 import pytest
 
 from strax.testutils import *
```
Second changed file:

```diff
@@ -13,87 +13,33 @@ def rechunk_array_to_arrays(array, n: int):
         yield array[i:i + n]
 
 
-def drop_random(chunks: list) -> list:
+def drop_random(input_array: np.ndarray) -> np.ndarray:
     """
-    Drop some of the data in the chunks
-    :param chunks: list op numpy arrays to modify. Here we will drop some of the fields randomly
-    :return: list of chunks
+    Drop some of the data in the input array
+    :param input_array: numpy array to modify. Here we will drop some
+    of the indices in the array randomly
+    :return: random selection of the input data
     """
-    res = []
-    for chunk in chunks:
-        if len(chunk) > 1:
-            # We are going to keep this many items in this chunk
-            keep_n = np.random.randint(1, len(chunk)+1)
-            # These are the indices we will keep (only keep unique ones)
-            keep_indices = np.random.randint(0, len(chunk)-1, keep_n)
-            keep_indices = np.unique(keep_indices)
-            keep_indices.sort()
-
-            # This chunk will now be reduced using only keep_indices
-            d = chunk[keep_indices]
-            res.append(d)
-    return res
-
-
-@given(get_some_array().filter(lambda x: len(x) >= 0),
-       strategies.integers(min_value=1, max_value=10))
-@settings(deadline=None)
-@example(
-    big_data=np.array(
-        [(0, 0, 1, 1),
-         (1, 1, 1, 1),
-         (5, 2, 2, 1),
-         (11, 4, 2, 4)],
-        dtype=full_dt_dtype),
-    nchunks=2)
-def test_loop_plugin(big_data, nchunks):
-    """Test the loop plugin for random data"""
-    _loop_test_inner(big_data, nchunks)
-
-
-@given(get_some_array().filter(lambda x: len(x) >= 0),
-       strategies.integers(min_value=1, max_value=10))
-@settings(deadline=None)
-@example(
-    big_data=np.array(
-        [(0, 0, 1, 1),
-         (1, 1, 1, 1),
-         (5, 2, 2, 1),
-         (11, 4, 2, 4)],
-        dtype=full_dt_dtype),
-    nchunks=2)
-def test_loop_plugin_multi_output(big_data, nchunks,):
-    """
-    Test the loop plugin for random data where it should give multiple
-    outputs
-    """
-    _loop_test_inner(big_data, nchunks, target='other_combined_things')
-
-
-@given(get_some_array().filter(lambda x: len(x) == 0),
-       strategies.integers(min_value=2, max_value=10))
-@settings(deadline=None)
-@example(
-    big_data=np.array(
-        [],
-        dtype=full_dt_dtype),
-    nchunks=2)
-def test_value_error_for_loop_plugin(big_data, nchunks):
-    """Make sure that we are are getting the right ValueError"""
-    try:
-        _loop_test_inner(big_data, nchunks, force_value_error=True)
-        raise RuntimeError(
-            'did not run into ValueError despite the fact we are having '
-            'multiple none-type chunks')
-    except ValueError:
-        # Good we got the ValueError we wanted
-        pass
+    if len(input_array) > 1:
+        # We are going to keep this many items in this array
+        keep_n = np.random.randint(1, len(input_array) + 1)
+        # These are the indices we will keep (only keep unique ones)
+        keep_indices = list(np.random.randint(0, len(input_array) - 1, keep_n))
+        keep_indices = np.unique(keep_indices)
+        keep_indices.sort()
+        # This chunk will now be reduced using only keep_indices
+        return input_array[keep_indices]
+    else:
+        return input_array
 
 
-def _loop_test_inner(big_data, nchunks, target='added_thing', force_value_error=False):
+def _loop_test_inner(big_data,
+                     nchunks,
+                     target='added_thing',
+                     force_value_error=False):
     """
     Test loop plugins for random data. For this test we are going to
-    setup to plugins that will be looped over and combined into a loop
+    setup two plugins that will be looped over and combined into a loop
     plugin (depending on the target, this may be a multi output plugin).
 
     We are going to setup as follows:
```
```diff
@@ -112,12 +58,8 @@ def _loop_test_inner(big_data, nchunks, target='added_thing', force_value_error=
 
     _dtype = big_data.dtype
 
-    # TODO smarter test. I want to drop some random data from the
-    # small_chunks but this does not work yet. Perhaps related to
-    # https://github.com/AxFoundation/strax/pull/345 (will fix in that
-    # PR)
-    # small_chunks = drop_random(big_chunks.copy())  # What I want to do
-    small_chunks = big_chunks
+    # Keep track fo the chunks seen in BigThing
+    _big_chunks_seen = []
 
     class BigThing(strax.Plugin):
         """Plugin that provides data for looping over"""
```

Review comment: L39 typo to -> two
Reply: thanks
```diff
@@ -128,14 +70,17 @@ class BigThing(strax.Plugin):
 
         def compute(self, chunk_i):
             data = big_chunks[chunk_i]
-            chunk = self.chunk(
-                data=data,
-                start=(
-                    int(data[0]['time']) if len(data)
-                    else np.arange(len(big_chunks))[chunk_i]),
-                end=(
-                    int(strax.endtime(data[-1])) if len(data)
-                    else np.arange(1, len(big_chunks) + 1)[chunk_i]))
+            # First determine start (t0) and stop (t1) times for the chunk
+            if chunk_i == 0:
+                t0 = int(data[0]['time']) if chunk_i > 0 else 0
+                t1 = int(strax.endtime(data[-1])) if len(data) else 1
+            else:
+                # Just take the previous chunk and take that as start time
+                t0 = _big_chunks_seen[chunk_i-1].end
+                t1 = int(strax.endtime(data[-1]) if len(data) else t0 + 1)
+
+            chunk = self.chunk(data=data, start=t0, end=t1)
+            _big_chunks_seen.append(chunk)
             return chunk
 
         def is_ready(self, chunk_i):
```
```diff
@@ -146,51 +91,40 @@ def source_finished(self):
             return True
 
     class SmallThing(strax.CutPlugin):
-        """Minimal working example of CutPlugin"""
-        depends_on = tuple()
+        """Throw away some of the data in big_thing"""
+        depends_on = 'big_thing'
         provides = 'small_thing'
         data_kind = 'small_kinda_data'
         dtype = _dtype
         rechunk_on_save = True
 
-        def compute(self, chunk_i):
-            data = small_chunks[chunk_i]
-            chunk = self.chunk(
-                data=data,
-                start=(
-                    int(data[0]['time']) if len(data)
-                    else np.arange(len(small_chunks))[chunk_i]),
-                end=(
-                    int(strax.endtime(data[-1])) if len(data)
-                    else np.arange(1, len(small_chunks) + 1)[chunk_i]))
-            return chunk
-
-        def is_ready(self, chunk_i):
-            # Hack to make peak output stop after a few chunks
-            return chunk_i < len(small_chunks)
-
-        def source_finished(self):
-            return True
+        def compute(self, big_kinda_data):
+            # Drop some of the data in big_kinda_data
+            return drop_random(big_kinda_data)
 
     class AddBigToSmall(strax.LoopPlugin):
         """
         Test loop plugin by looping big_thing and adding whatever is in small_thing
         """
         depends_on = 'big_thing', 'small_thing'
         provides = 'added_thing'
-        loop_over = 'big_thing'  # Also just test this feature
+        loop_over = 'big_kinda_data'  # Also just test this feature
 
         def infer_dtype(self):
             # Get the dtype from the dependency
             return self.deps['big_thing'].dtype
 
-        def compute(self, big_kinda_data, small_kinda_data):
-            res = np.zeros(len(big_kinda_data), dtype=self.dtype)
-            for k in res.dtype.names:
+        def compute_loop(self, big_kinda_data, small_kinda_data):
+            res = {}
+            for k in self.dtype.names:
                 if k == _dtype_name:
                     res[k] = big_kinda_data[k]
                     for small_bit in small_kinda_data[k]:
-                        for i in range(len(res[k])):
-                            res[k][i] += small_bit
+                        if np.iterable(res[k]):
+                            for i in range(len(res[k])):
+                                res[k][i] += small_bit
+                        else:
+                            res[k] += small_bit
                 else:
                     res[k] = big_kinda_data[k]
             return res
```
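A LoopPlugin's compute_loop receives one looped-over row at a time, so res[k] may be either a scalar field or an array field; the new np.iterable branch handles both. A standalone sketch of just that branch (add_small_bits is a hypothetical helper, not part of the PR):

```python
import numpy as np


def add_small_bits(big_value, small_bits):
    """Add every element of small_bits to big_value, whether big_value
    is a scalar field or an array field (mirrors the np.iterable branch)."""
    result = np.copy(big_value)  # 0-d array for scalars, 1-d otherwise
    for small_bit in small_bits:
        if np.iterable(result):
            # Array-valued field: add element-wise
            for i in range(len(result)):
                result[i] += small_bit
        else:
            # Scalar (0-d) field: plain in-place addition
            result += small_bit
    return result


assert int(add_small_bits(1, [2, 3])) == 6
assert add_small_bits(np.array([1, 1]), [2, 3]).tolist() == [6, 6]
```

Note that `np.copy` of a Python scalar yields a 0-d array, which `np.iterable` reports as non-iterable, so both branches are exercised by the two assertions above.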
```diff
@@ -205,22 +139,85 @@ def infer_dtype(self):
             # NB! This should be a dict for the kind of provide arguments
             return {k: self.deps['big_thing'].dtype for k in self.provides}
 
-        def compute(self, big_kinda_data, small_kinda_data):
-            res = np.zeros(len(big_kinda_data), _dtype)
-            for k in res.dtype.names:
+        def compute_loop(self, big_kinda_data, small_kinda_data):
+            res = {}
+            for k in self.dtype['some_combined_things'].names:
                 if k == _dtype_name:
                     res[k] = big_kinda_data[k]
                     for small_bit in small_kinda_data[k]:
-                        for i in range(len(res[k])):
-                            res[k][i] += small_bit
+                        if np.iterable(res[k]):
+                            for i in range(len(res[k])):
+                                res[k][i] += small_bit
+                        else:
+                            res[k] += small_bit
                 else:
                     res[k] = big_kinda_data[k]
             return {k: res for k in self.provides}
 
     with tempfile.TemporaryDirectory() as temp_dir:
         st = strax.Context(storage=[strax.DataDirectory(temp_dir)])
         st.register((BigThing, SmallThing, AddBigToSmall, AddBigToSmallMultiOutput))
 
+        # Make small thing in order to allow re-chunking
+        st.make(run_id='some_run', targets='small_thing')
+
         # Make the loop plugin
         result = st.get_array(run_id='some_run', targets=target)
         assert np.shape(result) == np.shape(big_data), 'Looping over big_data resulted in a different datasize?!'
         assert np.sum(result[_dtype_name]) >= np.sum(big_data[_dtype_name]), "Result should be at least as big as big_data because we added small_data data"
         assert isinstance(result, np.ndarray), "Result is not ndarray?"
+
+
+@given(get_some_array().filter(lambda x: len(x) >= 0),
+       strategies.integers(min_value=1, max_value=10))
+@settings(deadline=None)
+@example(
+    big_data=np.array(
+        [(0, 0, 1, 1),
+         (1, 1, 1, 1),
+         (5, 2, 2, 1),
+         (11, 4, 2, 4)],
+        dtype=full_dt_dtype),
+    nchunks=2)
+def test_loop_plugin(big_data, nchunks):
+    """Test the loop plugin for random data"""
+    _loop_test_inner(big_data, nchunks)
+
+
+@given(get_some_array().filter(lambda x: len(x) >= 0),
+       strategies.integers(min_value=1, max_value=10))
+@settings(deadline=None)
+@example(
+    big_data=np.array(
+        [(0, 0, 1, 1),
+         (1, 1, 1, 1),
+         (5, 2, 2, 1),
+         (11, 4, 2, 4)],
+        dtype=full_dt_dtype),
+    nchunks=2)
+def test_loop_plugin_multi_output(big_data, nchunks,):
+    """
+    Test the loop plugin for random data where it should give multiple
+    outputs
+    """
+    _loop_test_inner(big_data, nchunks, target='other_combined_things')
+
+
+@given(get_some_array().filter(lambda x: len(x) == 0),
+       strategies.integers(min_value=2, max_value=10))
+@settings(deadline=None)
+@example(
+    big_data=np.array(
+        [],
+        dtype=full_dt_dtype),
+    nchunks=2)
+def test_value_error_for_loop_plugin(big_data, nchunks):
+    """Make sure that we are are getting the right ValueError"""
+    try:
+        _loop_test_inner(big_data, nchunks, force_value_error=True)
+        raise RuntimeError(
+            'did not run into ValueError despite the fact we are having '
+            'multiple none-type chunks')
+    except ValueError:
+        # Good we got the ValueError we wanted
+        pass
```
Review comment: ?
Reply: there probably was an import statement during some point of the live time of this PR