Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix processing algorithm edge cases #94

Merged
merged 3 commits into from
Sep 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 31 additions & 8 deletions strax/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
from enum import IntEnum
import itertools
import logging
from functools import partial
import typing
import time
Expand Down Expand Up @@ -367,8 +368,8 @@ class OverlapWindowPlugin(Plugin):
a certain window on both sides.

Current implementation assumes:
- All inputs are sorted by endtime. Since everything in strax is sorted by
time, this only works for disjoint intervals such as peaks or events,
- All inputs are sorted by *endtime*. Since everything in strax is sorted
by time, this only works for disjoint intervals such as peaks or events,
but NOT records!
- You must read time info for your data kind, or create a new data kind.
"""
Expand All @@ -379,6 +380,8 @@ def __init__(self):
self.cached_input = {}
self.cached_results = None
self.last_threshold = -float('inf')
# This guy can have a logger, it's not parallelized anyway
self.log = logging.getLogger(self.__class__.__name__)

def get_window_size(self):
"""Return the required window size in nanoseconds"""
Expand All @@ -389,33 +392,53 @@ def iter(self, iters, executor=None):

# Yield results initially suppressed in fear of a next chunk
if self.cached_results is not None and len(self.cached_results):
self.log.debug(f"Last chunk! Sending out cached result "
f"{self.cached_results}")
yield self.cached_results
else:
self.log.debug("Last chunk! No cached results to send.")

def do_compute(self, chunk_i=None, **kwargs):
if not len(kwargs):
raise RuntimeError("OverlapWindowPlugin must have a dependency")
end = max([strax.endtime(x[-1])
for x in kwargs.values()])
invalid_beyond = end - self.get_window_size()
cache_inputs_beyond = end - 3 * self.get_window_size()
# Take slightly larger windows for safety: it is very easy for me
# (or the user) to have made an off-by-one error
# TODO: why do tests not fail if I set cache_inputs_beyond to
# end - window size - 2 ?
# (they do fail if I set it to end - 0.5 * window size - 2)
invalid_beyond = end - self.get_window_size() - 1
cache_inputs_beyond = end - 2 * self.get_window_size() - 1
self.log.debug(f"Invalid beyond {invalid_beyond}, "
f"caching inputs beyond {cache_inputs_beyond}")

for k, v in kwargs.items():
if len(self.cached_input):
kwargs[k] = v = np.concatenate([self.cached_input[k], v])
self.cached_input[k] = v[strax.endtime(v) > cache_inputs_beyond]

self.log.debug(f"Cached input {self.cached_input}")

self.log.debug(f"Compute kwargs {kwargs}")
result = super().do_compute(chunk_i=chunk_i, **kwargs)

endtimes = strax.endtime(kwargs[self.data_kind]
if self.data_kind in kwargs
else result)
assert len(endtimes) == len(result)

# Remove results that are invalid or already sent out last time
is_valid = endtimes < invalid_beyond
self.cached_results = result[~is_valid]
result = result[is_valid
& (endtimes > self.last_threshold)]
not_sent_yet = endtimes >= self.last_threshold

# Cache all results we have not sent, nor are sending now
self.cached_results = result[not_sent_yet & (~is_valid)]

# Send out only valid results we haven't sent yet
result = result[is_valid & not_sent_yet]

self.log.debug(f"Cached results {self.cached_results}")
self.log.debug(f"Sending out result {result}")

self.last_threshold = invalid_beyond
return result
Expand Down
28 changes: 17 additions & 11 deletions strax/processing/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def find_break_i(x, safe_break, tolerant=True):
def fully_contained_in(things, containers):
"""Return array of len(things) with index of interval in containers
for which things are fully contained in a container, or -1 if no such
exists. We assume all intervals are sorted by time, and b_intervals
exists.
We assume all intervals are sorted by time, and containers are
nonoverlapping.
"""
result = np.ones(len(things), dtype=np.int32) * -1
Expand Down Expand Up @@ -132,20 +133,25 @@ def split_by_containment(things, containers):
if not len(containers):
return []

# Index of which container each thing belongs to, or -1
which_container = fully_contained_in(things, containers)

# Restrict to things in containers
mask = which_container != -1
things = things[mask]
which_container = which_container[mask]
things_split = np.split(
things,
np.where(np.diff(which_container))[0] + 1)

# Insert empties for containers with nothing
for c in np.setdiff1d(np.arange(len(containers)),
np.unique(which_container)):
if c == 0:
continue # np.split already produces an empty in this case?
things_split.insert(c, things[:0])
if not len(things):
# np.split has confusing behaviour for empty arrays
return [things[:0] for _ in range(len(containers))]

# Split things up by container
split_indices = np.where(np.diff(which_container))[0] + 1
things_split = np.split(things, split_indices)

# Insert empty arrays for empty containers
empty_containers = np.setdiff1d(np.arange(len(containers)),
np.unique(which_container))
for c_i in empty_containers:
things_split.insert(c_i, things[:0])

return things_split
16 changes: 15 additions & 1 deletion tests/test_general_processing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from hypothesis import given
from hypothesis import given, example
from .helpers import sorted_intervals, disjoint_sorted_intervals
from .helpers import several_fake_records

Expand All @@ -7,6 +7,14 @@


@given(sorted_intervals, disjoint_sorted_intervals)
# Tricky example: uncontained interval precedes contained interval
# (this did not produce an issue, but good to show this is handled)
@example(things=np.array([(0, 1, 0, 1),
(0, 1, 1, 5),
(0, 1, 2, 1)],
dtype=strax.interval_dtype),
containers=np.array([(0, 1, 0, 4)],
dtype=strax.interval_dtype))
def test_fully_contained_in(things, containers):
result = strax.fully_contained_in(things, containers)

Expand All @@ -25,6 +33,12 @@ def test_fully_contained_in(things, containers):


@given(sorted_intervals, disjoint_sorted_intervals)
# Specific example to trigger issue #37
@example(
things=np.array([(0, 1, 2, 1)],
dtype=strax.interval_dtype),
containers=np.array([(0, 1, 0, 1), (0, 1, 2, 1)],
dtype=strax.interval_dtype))
def test_split_by_containment(things, containers):
result = strax.split_by_containment(things, containers)

Expand Down
33 changes: 23 additions & 10 deletions tests/test_overlap_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,37 @@

import numpy as np

from hypothesis import given, strategies
from hypothesis import given, strategies, example

import strax


@given(helpers.disjoint_sorted_intervals.filter(lambda x: len(x) > 0),
strategies.integers(min_value=0, max_value=3))
# Examples that trigger issue #49
@example(
input_peaks=np.array(
[(0, 1, 0, 1), (0, 1, 1, 10), (0, 1, 11, 1)],
dtype=strax.interval_dtype),
split_i=2)
@example(
input_peaks=np.array(
[(0, 1, 0, 1), (0, 1, 1, 1), (0, 1, 2, 9), (0, 1, 11, 1)],
dtype=strax.interval_dtype),
split_i=3)
# Other example that caused failures at some point
@example(
input_peaks=np.array(
[(0, 1, 0, 1), (0, 1, 7, 6), (0, 1, 13, 1)],
dtype=strax.interval_dtype),
split_i=2
)
def test_overlap_plugin(input_peaks, split_i):
"""Counting number of nearby peaks should not depend on the chunking"""
"""Counting the number of nearby peaks should not depend on how peaks are
chunked.
"""
chunks = np.split(input_peaks, [split_i])
chunks = [c for c in chunks if len(c)]
print("\n\nNew run")
print(strax.endtime(input_peaks),
[strax.endtime(c) for c in chunks],
[len(c) for c in chunks])

class Peaks(strax.Plugin):
depends_on = tuple()
Expand All @@ -25,9 +41,8 @@ class Peaks(strax.Plugin):
def compute(self, chunk_i):
return chunks[chunk_i]

# Hack to make peak output stop after a few chunks
def is_ready(self, chunk_i):
print(f"ready check for {chunk_i}, {len(chunks)},"
f" {chunk_i < len(chunks)}")
return chunk_i < len(chunks)

def source_finished(self):
Expand Down Expand Up @@ -55,10 +70,8 @@ def get_window_size(self):
return window

def compute(self, peaks):
print(f"Compute got {strax.endtime(peaks)}")
result = dict(
n_within_window=count_in_window(strax.endtime(peaks)))
print(f"Result is %s" % result)
return result

def iter(self, *args, **kwargs):
Expand Down