Add chunk yielding plugin and tests #769

Merged: 9 commits, Nov 23, 2023

1 change: 1 addition & 0 deletions strax/plugins/__init__.py
@@ -4,3 +4,4 @@
from .merge_only_plugin import *
from .overlap_window_plugin import *
from .parrallel_source_plugin import *
from .down_chunking_plugin import *
44 changes: 44 additions & 0 deletions strax/plugins/down_chunking_plugin.py
@@ -0,0 +1,44 @@
import strax
from .plugin import Plugin

export, __all__ = strax.exporter()


##
# Plugin which allows the use of yield in a plugin's compute method.
# This allows the output to be chunked down before it is stored to disk.
# Only works if multiprocessing is omitted.
##


@export
class DownChunkingPlugin(Plugin):
"""Plugin that merges data from its dependencies."""

parallel = False

def __init__(self):
super().__init__()

if self.parallel:
raise NotImplementedError(
f'Plugin "{self.__class__.__name__}" is a DownChunkingPlugin which '
"currently does not support parallel processing."
)

if self.multi_output:
raise NotImplementedError(
f'Plugin "{self.__class__.__name__}" is a DownChunkingPlugin which '
"currently does not support multiple outputs. Please only provide "
"a single data-type."
)

def iter(self, iters, executor=None):
return super().iter(iters, executor)

def _iter_compute(self, chunk_i, **inputs_merged):
return self.do_compute(chunk_i=chunk_i, **inputs_merged)

def _fix_output(self, result, start, end, _dtype=None):
"""Wrapper around _fix_output to support the return of iterators."""
return result
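
For readers skimming the diff: below is a minimal sketch of how a subclass of this new base class could look. The plugin name, data-type name, and split logic are illustrative only and are not part of this PR; the example actually exercised by the tests is DownSampleRecords in strax/testutils.py further down. The key point is that compute() becomes a generator and each yielded self.chunk(...) is stored as its own chunk.

import numpy as np
import strax


class SplitRecordsExample(strax.DownChunkingPlugin):
    """Hypothetical plugin: splits every incoming chunk of 'records' in two."""

    provides = "records_split"  # illustrative data-type name
    depends_on = "records"
    dtype = strax.record_dtype()
    rechunk_on_save = False

    def compute(self, records, start, end):
        # compute() is a generator: each yield becomes its own stored chunk.
        mid = len(records) // 2
        first, second = records[:mid], records[mid:]
        # End the first chunk at the end time of its last record so the two
        # chunks stay contiguous and non-overlapping.
        split_time = int(np.max(strax.endtime(first))) if len(first) else start
        yield self.chunk(start=start, end=split_time, data=first)
        yield self.chunk(start=split_time, end=end, data=second)
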
8 changes: 7 additions & 1 deletion strax/plugins/plugin.py
@@ -133,6 +133,8 @@ def __init__(self):
# not have to update save_when
self.save_when = immutabledict.fromkeys(self.provides, self.save_when)

if getattr(self, "provides", None):
self.provides = strax.to_str_tuple(self.provides)
self.compute_pars = compute_pars
self.input_buffer = dict()

@@ -492,7 +494,7 @@ class IterDone(Exception):
pending_futures = [f for f in pending_futures if not f.done()]
yield new_future
else:
yield self.do_compute(chunk_i=chunk_i, **inputs_merged)
yield from self._iter_compute(chunk_i=chunk_i, **inputs_merged)

except IterDone:
# Check all sources are exhausted.
@@ -517,6 +519,10 @@ class IterDone(Exception):
finally:
self.cleanup(wait_for=pending_futures)

def _iter_compute(self, chunk_i, **inputs_merged):
"""Either yields or returns strax chunks from the input."""
yield self.do_compute(chunk_i=chunk_i, **inputs_merged)

def cleanup(self, wait_for):
pass
# A standard plugin doesn't need to do anything here
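
The plugin.py change above is the enabling hook: iter() now routes results through _iter_compute(), and yield from behaves identically whether that hook yields a single chunk (the base Plugin) or hands back a generator that yields several (the DownChunkingPlugin override). A stand-alone toy illustration of that delegation pattern, not strax code:

def one_result(x):
    # Base-Plugin-style hook: a generator that yields exactly one item.
    yield x * 2


def many_results(x):
    # DownChunkingPlugin-style hook: yields an arbitrary number of items.
    for part in range(x):
        yield part


def run(hook, values):
    for v in values:
        # The call site is identical in both cases thanks to `yield from`.
        yield from hook(v)


print(list(run(one_result, [1, 2])))   # [2, 4]
print(list(run(many_results, [3])))    # [0, 1, 2]
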
52 changes: 52 additions & 0 deletions strax/testutils.py
@@ -252,6 +252,58 @@ def compute(self, peaks):
return dict(peak_classification=p, lone_hits=lh)


# Plugins with time structure within chunks,
# used to test down chunking within plugin compute.
class RecordsWithTimeStructure(Records):
"""Same as Records but with some structure in "time" for testing."""

def setup(self):
self.last_end = 0

def compute(self, chunk_i):
r = np.zeros(self.config["recs_per_chunk"], self.dtype)
r["time"] = self.last_end + np.arange(self.config["recs_per_chunk"]) + 5
r["length"] = r["dt"] = 1
r["channel"] = np.arange(len(r))

end = self.last_end + self.config["recs_per_chunk"] + 10
chunk = self.chunk(start=self.last_end, end=end, data=r)
self.last_end = end

return chunk


class DownSampleRecords(strax.DownChunkingPlugin):
"""PLugin to test the downsampling of Chunks during compute.

Needed for simulations.

"""

provides = "records_down_chunked"
depends_on = "records"
dtype = strax.record_dtype()
rechunk_on_save = False

def compute(self, records, start, end):
offset = 0
last_start = start

count = 0
for count, r in enumerate(records):
if count == 5:
res = records[offset:count]
chunk_end = np.max(strax.endtime(res))
offset = count
chunk = self.chunk(start=last_start, end=chunk_end, data=res)
last_start = chunk_end
yield chunk

res = records[offset : count + 1]
chunk = self.chunk(start=last_start, end=end, data=res)
yield chunk


# Used in test_core.py
run_id = "0"

12 changes: 9 additions & 3 deletions tests/test_context.py
@@ -1,5 +1,11 @@
import strax
from strax.testutils import Records, Peaks, PeaksWoPerRunDefault, PeakClassification, run_id
from strax.testutils import (
Records,
Peaks,
PeaksWoPerRunDefault,
PeakClassification,
run_id,
)
import tempfile
import numpy as np
from hypothesis import given, settings
@@ -301,10 +307,10 @@ def test_deregister(self):
st.deregister_plugins_with_missing_dependencies()
assert st._plugin_class_registry.pop("peaks", None) is None

def get_context(self, use_defaults):
def get_context(self, use_defaults, **kwargs):
"""Get simple context where we have one mock run in the only storage frontend."""
assert isinstance(use_defaults, bool)
st = strax.Context(storage=self.get_mock_sf(), check_available=("records",))
st = strax.Context(storage=self.get_mock_sf(), check_available=("records",), **kwargs)
st.set_context_config({"use_per_run_defaults": use_defaults})
return st

69 changes: 69 additions & 0 deletions tests/test_down_chunk_plugin.py
@@ -0,0 +1,69 @@
from strax.testutils import RecordsWithTimeStructure, DownSampleRecords, run_id
import strax
import numpy as np

import os
import tempfile
import shutil
import uuid
import unittest


class TestContext(unittest.TestCase):
"""Tests for DownChunkPlugin class."""

def setUp(self):
"""Make temp folder to write data to."""
temp_folder = uuid.uuid4().hex
self.tempdir = os.path.join(tempfile.gettempdir(), temp_folder)
assert not os.path.exists(self.tempdir)

def tearDown(self):
if os.path.exists(self.tempdir):
shutil.rmtree(self.tempdir)

def test_down_chunking(self):
st = self.get_context()
st.register(RecordsWithTimeStructure)
st.register(DownSampleRecords)

st.make(run_id, "records")
st.make(run_id, "records_down_chunked")

chunks_records = st.get_meta(run_id, "records")["chunks"]
chunks_records_down_chunked = st.get_meta(run_id, "records_down_chunked")["chunks"]

_chunks_are_downsampled = len(chunks_records) * 2 == len(chunks_records_down_chunked)
assert _chunks_are_downsampled

_chunks_are_continuous = np.all([
chunks_records_down_chunked[i]["end"] == chunks_records_down_chunked[i + 1]["start"]
for i in range(len(chunks_records_down_chunked) - 1)
])
assert _chunks_are_continuous

def test_down_chunking_multi_processing(self):
st = self.get_context(allow_multiprocess=True)
st.register(RecordsWithTimeStructure)
st.register(DownSampleRecords)

st.make(run_id, "records", max_workers=1)

class TestMultiProcessing(DownSampleRecords):
parallel = True

st.register(TestMultiProcessing)
with self.assertRaises(NotImplementedError):
st.make(run_id, "records_down_chunked", max_workers=2)

def get_context(self, **kwargs):
"""Simple context to run tests."""
st = strax.Context(storage=self.get_mock_sf(), check_available=("records",), **kwargs)
return st

def get_mock_sf(self):
mock_rundb = [{"name": "0", strax.RUN_DEFAULTS_KEY: dict(base_area=43)}]
sf = strax.DataDirectory(path=self.tempdir, deep_scan=True, provide_run_metadata=True)
for d in mock_rundb:
sf.write_run_metadata(d["name"], d)
return sf