diff --git a/strax/plugins/__init__.py b/strax/plugins/__init__.py index ada9e1ef0..ffbb5b014 100644 --- a/strax/plugins/__init__.py +++ b/strax/plugins/__init__.py @@ -4,3 +4,4 @@ from .merge_only_plugin import * from .overlap_window_plugin import * from .parrallel_source_plugin import * +from .down_chunking_plugin import * diff --git a/strax/plugins/down_chunking_plugin.py b/strax/plugins/down_chunking_plugin.py new file mode 100644 index 000000000..fb34e2b29 --- /dev/null +++ b/strax/plugins/down_chunking_plugin.py @@ -0,0 +1,44 @@ +import strax +from .plugin import Plugin + +export, __all__ = strax.exporter() + + +## +# Plugin which allows to use yield in plugins compute method. +# Allows to chunk down output before storing to disk. +# Only works if multiprocessing is omitted. +## + + +@export +class DownChunkingPlugin(Plugin): + """Plugin that merges data from its dependencies.""" + + parallel = False + + def __init__(self): + super().__init__() + + if self.parallel: + raise NotImplementedError( + f'Plugin "{self.__class__.__name__}" is a DownChunkingPlugin which ' + "currently does not support parallel processing." + ) + + if self.multi_output: + raise NotImplementedError( + f'Plugin "{self.__class__.__name__}" is a DownChunkingPlugin which ' + "currently does not support multiple outputs. Please only provide " + "a single data-type." + ) + + def iter(self, iters, executor=None): + return super().iter(iters, executor) + + def _iter_compute(self, chunk_i, **inputs_merged): + return self.do_compute(chunk_i=chunk_i, **inputs_merged) + + def _fix_output(self, result, start, end, _dtype=None): + """Wrapper around _fix_output to support the return of iterators.""" + return result diff --git a/strax/plugins/plugin.py b/strax/plugins/plugin.py index ee95d22c6..afaa9d161 100644 --- a/strax/plugins/plugin.py +++ b/strax/plugins/plugin.py @@ -133,6 +133,8 @@ def __init__(self): # not have to updated save_when self.save_when = immutabledict.fromkeys(self.provides, self.save_when) + if getattr(self, "provides", None): + self.provides = strax.to_str_tuple(self.provides) self.compute_pars = compute_pars self.input_buffer = dict() @@ -492,7 +494,7 @@ class IterDone(Exception): pending_futures = [f for f in pending_futures if not f.done()] yield new_future else: - yield self.do_compute(chunk_i=chunk_i, **inputs_merged) + yield from self._iter_compute(chunk_i=chunk_i, **inputs_merged) except IterDone: # Check all sources are exhausted. @@ -517,6 +519,10 @@ class IterDone(Exception): finally: self.cleanup(wait_for=pending_futures) + def _iter_compute(self, chunk_i, **inputs_merged): + """Either yields or returns strax chunks from the input.""" + yield self.do_compute(chunk_i=chunk_i, **inputs_merged) + def cleanup(self, wait_for): pass # A standard plugin doesn't need to do anything here diff --git a/strax/testutils.py b/strax/testutils.py index 1d2f25e72..eea791e63 100644 --- a/strax/testutils.py +++ b/strax/testutils.py @@ -252,6 +252,58 @@ def compute(self, peaks): return dict(peak_classification=p, lone_hits=lh) +# Plugins with time structure within chunks, +# used to test down chunking within plugin compute. +class RecordsWithTimeStructure(Records): + """Same as Records but with some structure in "time" for testing.""" + + def setup(self): + self.last_end = 0 + + def compute(self, chunk_i): + r = np.zeros(self.config["recs_per_chunk"], self.dtype) + r["time"] = self.last_end + np.arange(self.config["recs_per_chunk"]) + 5 + r["length"] = r["dt"] = 1 + r["channel"] = np.arange(len(r)) + + end = self.last_end + self.config["recs_per_chunk"] + 10 + chunk = self.chunk(start=self.last_end, end=end, data=r) + self.last_end = end + + return chunk + + +class DownSampleRecords(strax.DownChunkingPlugin): + """PLugin to test the downsampling of Chunks during compute. + + Needed for simulations. + + """ + + provides = "records_down_chunked" + depends_on = "records" + dtype = strax.record_dtype() + rechunk_on_save = False + + def compute(self, records, start, end): + offset = 0 + last_start = start + + count = 0 + for count, r in enumerate(records): + if count == 5: + res = records[offset:count] + chunk_end = np.max(strax.endtime(res)) + offset = count + chunk = self.chunk(start=last_start, end=chunk_end, data=res) + last_start = chunk_end + yield chunk + + res = records[offset : count + 1] + chunk = self.chunk(start=last_start, end=end, data=res) + yield chunk + + # Used in test_core.py run_id = "0" diff --git a/tests/test_context.py b/tests/test_context.py index 7602277fa..067aeed17 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,5 +1,11 @@ import strax -from strax.testutils import Records, Peaks, PeaksWoPerRunDefault, PeakClassification, run_id +from strax.testutils import ( + Records, + Peaks, + PeaksWoPerRunDefault, + PeakClassification, + run_id, +) import tempfile import numpy as np from hypothesis import given, settings @@ -301,10 +307,10 @@ def test_deregister(self): st.deregister_plugins_with_missing_dependencies() assert st._plugin_class_registry.pop("peaks", None) is None - def get_context(self, use_defaults): + def get_context(self, use_defaults, **kwargs): """Get simple context where we have one mock run in the only storage frontend.""" assert isinstance(use_defaults, bool) - st = strax.Context(storage=self.get_mock_sf(), check_available=("records",)) + st = strax.Context(storage=self.get_mock_sf(), check_available=("records",), **kwargs) st.set_context_config({"use_per_run_defaults": use_defaults}) return st diff --git a/tests/test_down_chunk_plugin.py b/tests/test_down_chunk_plugin.py new file mode 100644 index 000000000..f0d3d8a90 --- /dev/null +++ b/tests/test_down_chunk_plugin.py @@ -0,0 +1,69 @@ +from strax.testutils import RecordsWithTimeStructure, DownSampleRecords, run_id +import strax +import numpy as np + +import os +import tempfile +import shutil +import uuid +import unittest + + +class TestContext(unittest.TestCase): + """Tests for DownChunkPlugin class.""" + + def setUp(self): + """Make temp folder to write data to.""" + temp_folder = uuid.uuid4().hex + self.tempdir = os.path.join(tempfile.gettempdir(), temp_folder) + assert not os.path.exists(self.tempdir) + + def tearDown(self): + if os.path.exists(self.tempdir): + shutil.rmtree(self.tempdir) + + def test_down_chunking(self): + st = self.get_context() + st.register(RecordsWithTimeStructure) + st.register(DownSampleRecords) + + st.make(run_id, "records") + st.make(run_id, "records_down_chunked") + + chunks_records = st.get_meta(run_id, "records")["chunks"] + chunks_records_down_chunked = st.get_meta(run_id, "records_down_chunked")["chunks"] + + _chunks_are_downsampled = len(chunks_records) * 2 == len(chunks_records_down_chunked) + assert _chunks_are_downsampled + + _chunks_are_continues = np.all([ + chunks_records_down_chunked[i]["end"] == chunks_records_down_chunked[i + 1]["start"] + for i in range(len(chunks_records_down_chunked) - 1) + ]) + assert _chunks_are_continues + + def test_down_chunking_multi_processing(self): + st = self.get_context(allow_multiprocess=True) + st.register(RecordsWithTimeStructure) + st.register(DownSampleRecords) + + st.make(run_id, "records", max_workers=1) + + class TestMultiProcessing(DownSampleRecords): + parallel = True + + st.register(TestMultiProcessing) + with self.assertRaises(NotImplementedError): + st.make(run_id, "records_down_chunked", max_workers=2) + + def get_context(self, **kwargs): + """Simple context to run tests.""" + st = strax.Context(storage=self.get_mock_sf(), check_available=("records",), **kwargs) + return st + + def get_mock_sf(self): + mock_rundb = [{"name": "0", strax.RUN_DEFAULTS_KEY: dict(base_area=43)}] + sf = strax.DataDirectory(path=self.tempdir, deep_scan=True, provide_run_metadata=True) + for d in mock_rundb: + sf.write_run_metadata(d["name"], d) + return sf