Add chunk yielding plugin and tests #769

Merged: 9 commits, Nov 23, 2023
1 change: 1 addition & 0 deletions strax/plugins/__init__.py
@@ -4,3 +4,4 @@
from .merge_only_plugin import *
from .overlap_window_plugin import *
from .parrallel_source_plugin import *
from .down_chunking_plugin import *
40 changes: 40 additions & 0 deletions strax/plugins/down_chunking_plugin.py
@@ -0,0 +1,40 @@
import strax
import types
import inspect
from .plugin import Plugin
export, __all__ = strax.exporter()


##
# Plugin which allows the use of yield in a plugin's compute method.
# Allows chunking down the output before storing it to disk.
# Only works if multiprocessing is omitted.
##

@export
class DownChunkingPlugin(Plugin):
    """Plugin whose compute method may yield several smaller chunks per input
    chunk, so the output can be chunked down before it is stored to disk.
    Not supported with multi-threading/processing.
    """

def iter(self, iters, executor=None):

        _plugin_uses_multi_threading = (
            self.parallel
            and executor is not None
            and inspect.isgeneratorfunction(self.compute)
        )
if _plugin_uses_multi_threading:
raise NotImplementedError(
                f'Plugin "{self.__class__.__name__}" uses a generator as its compute method. '
                'This is not supported with multi-threading/processing.')
return super().iter(iters, executor=None)

def _iter_return(self, chunk_i, **inputs_merged):
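        # For a yielding compute method, do_compute returns a generator of
        # chunks (see _fix_output below), which the base iter() then yields from.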
return self.do_compute(chunk_i=chunk_i, **inputs_merged)

def _fix_output(self, result, start, end, _dtype=None):
"""Wrapper around _fix_output to support the return
of iterators.
"""
if isinstance(result, types.GeneratorType):
return result
return super()._fix_output(result, start, end)
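
Illustrative sketch (not part of this diff): one way a plugin built on DownChunkingPlugin could yield several chunks per compute call. The class name, provides target, and midpoint split are made up for illustration and assume no record straddles the midpoint; DownSampleRecords in strax/testutils.py further down is the test implementation actually added by this PR.

class SplitInTwo(strax.DownChunkingPlugin):
    provides = 'records_split'
    depends_on = 'records'
    dtype = strax.record_dtype()
    rechunk_on_save = False

    def compute(self, records, start, end):
        # Split each input chunk at its midpoint and yield two chunks.
        middle = start + (end - start) // 2
        below = records[strax.endtime(records) <= middle]
        above = records[strax.endtime(records) > middle]
        yield self.chunk(start=start, end=middle, data=below)
        yield self.chunk(start=middle, end=end, data=above)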
6 changes: 5 additions & 1 deletion strax/plugins/plugin.py
@@ -496,7 +496,7 @@ class IterDone(Exception):
pending_futures = [f for f in pending_futures if not f.done()]
yield new_future
else:
yield self.do_compute(chunk_i=chunk_i, **inputs_merged)
yield from self._iter_return(chunk_i=chunk_i, **inputs_merged)

except IterDone:
# Check all sources are exhausted.
@@ -523,6 +523,10 @@ class IterDone(Exception):
finally:
self.cleanup(wait_for=pending_futures)

def _iter_return(self, chunk_i, **inputs_merged):
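        # Default behaviour: wrap the single output of do_compute as a
        # one-chunk generator. DownChunkingPlugin overrides this to return the
        # generator produced by a yielding compute method instead.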
yield self.do_compute(chunk_i=chunk_i, **inputs_merged)


def cleanup(self, wait_for):
pass
# A standard plugin doesn't need to do anything here
68 changes: 68 additions & 0 deletions strax/testutils.py
@@ -260,6 +260,74 @@ def compute(self, peaks):
return dict(peak_classification=p,
lone_hits=lh)


# Plugins with time structure within chunks,
# used to test down chunking within plugin compute.
@strax.takes_config(
strax.Option('n_chunks', type=int, default=10, track=False),
strax.Option('recs_per_chunk', type=int, default=10, track=False),
)
class RecordsWithTimeStructure(strax.Plugin):
provides = 'records'
parallel = 'process'
depends_on = tuple()
dtype = strax.record_dtype()

rechunk_on_save = False

def source_finished(self):
return True

def is_ready(self, chunk_i):
return chunk_i < self.config['n_chunks']

def setup(self):
self.last_end = 0

def compute(self, chunk_i):

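        # Give the records some time structure within the chunk: the first
        # record starts 5 ns after the chunk start and the chunk end extends
        # beyond the last record.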
r = np.zeros(self.config['recs_per_chunk'], self.dtype)
r['time'] = self.last_end + np.arange(self.config['recs_per_chunk']) + 5
r['length'] = r['dt'] = 1
r['channel'] = np.arange(len(r))

end = self.last_end + self.config['recs_per_chunk'] + 10
chunk = self.chunk(start=self.last_end, end=end, data=r)
self.last_end = end

return chunk


class DownSampleRecords(strax.DownChunkingPlugin):
    """Plugin to test the down-chunking of chunks during compute. Needed
    for simulations.
    """

provides = 'records_down_chunked'
depends_on = 'records'
dtype = strax.record_dtype()
rechunk_on_save = False
    parallel = 'process'

def compute(self, records, start, end):
offset = 0
last_start = start

        count = 0
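        # Split each incoming chunk in two: the first five records are emitted
        # as their own chunk once the loop reaches them, the remainder after
        # the loop has finished.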
for count, r in enumerate(records):
if count == 5:
res = records[offset:count]
chunk_end = np.max(strax.endtime(res))
offset = count
chunk = self.chunk(start=last_start, end=chunk_end, data=res)
last_start = chunk_end
yield chunk

res = records[offset:count+1]
chunk = self.chunk(start=last_start, end=end, data=res)
yield chunk


# Used in test_core.py
run_id = '0'

35 changes: 32 additions & 3 deletions tests/test_context.py
@@ -1,5 +1,5 @@
import strax
from strax.testutils import Records, Peaks, PeaksWoPerRunDefault, PeakClassification, run_id
from strax.testutils import Records, Peaks, PeaksWoPerRunDefault, PeakClassification, RecordsWithTimeStructure, DownSampleRecords, run_id
import tempfile
import numpy as np
from hypothesis import given, settings
@@ -215,6 +215,34 @@ def tearDown(self):
if os.path.exists(self.tempdir):
shutil.rmtree(self.tempdir)

def test_down_chunking(self):
st = self.get_context(False)
st.register(RecordsWithTimeStructure)
st.register(DownSampleRecords)

st.make(run_id, 'records')
st.make(run_id, 'records_down_chunked')

chunks_records = st.get_meta(run_id, 'records')['chunks']
chunks_records_down_chunked = st.get_meta(run_id, 'records_down_chunked')['chunks']

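        # Each input chunk should be split into two contiguous output chunks.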
        _chunks_are_downsampled = len(chunks_records) * 2 == len(chunks_records_down_chunked)
        assert _chunks_are_downsampled

        _chunks_are_continuous = np.all(
            [chunks_records_down_chunked[i]['end'] == chunks_records_down_chunked[i + 1]['start']
             for i in range(len(chunks_records_down_chunked) - 1)])
        assert _chunks_are_continuous

def test_down_chunking_multi_processing(self):
st = self.get_context(False, allow_multiprocess=True)
st.set_context_config({'use_per_run_defaults': False})
st.register(RecordsWithTimeStructure)
st.register(DownSampleRecords)

st.make(run_id, 'records', max_workers=1)
with self.assertRaises(NotImplementedError):
st.make(run_id, 'records_down_chunked', max_workers=2)


def test_get_plugins_with_cache(self):
st = self.get_context(False)
st.register(Records)
Expand Down Expand Up @@ -283,11 +311,12 @@ def test_deregister(self):
st.deregister_plugins_with_missing_dependencies()
assert st._plugin_class_registry.pop('peaks', None) is None

def get_context(self, use_defaults):
def get_context(self, use_defaults, **kwargs):
"""Get simple context where we have one mock run in the only storage frontend"""
assert isinstance(use_defaults, bool)
st = strax.Context(storage=self.get_mock_sf(),
check_available=('records',)
check_available=('records',),
**kwargs
)
st.set_context_config({'use_per_run_defaults': use_defaults})
return st