Merge pull request #154 from JelleAalbers/faster_pulse_cutting
Pulse processing upgrades
JelleAalbers committed May 3, 2019
2 parents 897f705 + 1553b5c commit 76c87c8
Showing 6 changed files with 227 additions and 42 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -7,7 +7,6 @@ Streaming analysis for xenon experiments
[![Coverage Status](https://coveralls.io/repos/github/AxFoundation/strax/badge.svg?branch=master)](https://coveralls.io/github/AxFoundation/strax?branch=master)
[![PyPI version shields.io](https://img.shields.io/pypi/v/strax.svg)](https://pypi.python.org/pypi/strax/)
[![Join the chat at https://gitter.im/AxFoundation/strax](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/AxFoundation/strax?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Codacy Badge](https://api.codacy.com/project/badge/Grade/cc159474f2764d43b445d562a24ca245)](https://www.codacy.com/app/tunnell/strax?utm_source=github.com&utm_medium=referral&utm_content=AxFoundation/strax&utm_campaign=Badge_Grade)

Strax is an analysis framework for pulse-only digitization data, specialized for live data reduction at speeds of 50-100 MB(raw) / core / sec. For more information, please see the [strax documentation](https://strax.readthedocs.io).

1 change: 1 addition & 0 deletions requirements.txt
@@ -9,4 +9,5 @@ psutil==5.4.8
numexpr==2.6.9
boto3==1.9.78
npshmex==0.1.2
scipy
pymongo
87 changes: 62 additions & 25 deletions strax/processing/data_reduction.py
@@ -3,7 +3,7 @@
import numba
from enum import IntEnum

from strax.processing.pulse_processing import NOT_APPLICABLE, record_links
from strax.processing.pulse_processing import NO_RECORD_LINK, record_links
from strax.processing.peak_building import find_peaks
from .general import fully_contained_in
from strax.dtypes import peak_dtype
@@ -49,10 +49,9 @@ def cut_baseline(records, n_before=48, n_after=30):
records.reduction_level[:] = ReductionLevel.BASELINE_CUT


@numba.jit(nopython=True, nogil=True, cache=True)
def cut_outside_hits(records, hits, left_extension=2, right_extension=15):
"""Zero record waveforms not within left_extension or right_extension of
hits.
"""Return records with waveforms zeroed if not within
left_extension or right_extension of hits.
These extensions properly account for breaking of pulses into records.
If you pass an incomplete (e.g. cut) set of records, we will not save
@@ -61,38 +60,76 @@ def cut_outside_hits(records, hits, left_extension=2, right_extension=15):
"""
if not len(records):
return
samples_per_record = len(records[0]['data'])

# For every sample, store if we can cut it or not
can_cut = np.ones((len(records), samples_per_record), dtype=np.bool_)
# Create a copy of records with blanked data
# Even a simple records.copy() is mightily slow in numba,
# and assignments to struct arrays seem troublesome.
# The obvious solution:
# new_recs = records.copy()
# new_recs['data'] = 0
# is quite slow.
# Replacing the last = with *= gives a factor 2 speed boost.
# But ~40% faster still is this:
meta_fields = [x for x in records.dtype.names
if x not in ['data', 'reduction_level']]

new_recs = np.zeros(len(records), dtype=records.dtype)
new_recs[meta_fields] = records[meta_fields]
new_recs['reduction_level'] = ReductionLevel.HITS_ONLY

_cut_outside_hits(records, hits, new_recs,
left_extension, right_extension)

return new_recs
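The comment above records the timing trade-offs; here is a minimal standalone sketch of the same copy strategy, using a simplified record dtype (illustrative only, not the real strax record dtype):

import numpy as np

simple_dtype = np.dtype([
    ('time', np.int64), ('channel', np.int16),
    ('reduction_level', np.uint8), ('data', np.int16, (110,))])
records = np.zeros(5, dtype=simple_dtype)

# Copy every metadata field; the large 'data' field stays zeroed
meta_fields = [x for x in records.dtype.names
               if x not in ['data', 'reduction_level']]
new_recs = np.zeros(len(records), dtype=records.dtype)
new_recs[meta_fields] = records[meta_fields]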


@numba.jit(nopython=True, nogil=True, cache=True)
def _cut_outside_hits(records, hits, new_recs,
left_extension=2, right_extension=15):
if not len(records):
return
samples_per_record = len(records[0]['data'])

previous_record, next_record = record_links(records)

for hit_i in range(len(hits)):
h = hits[hit_i]
for hit_i, h in enumerate(hits):
rec_i = h['record_i']
r = records[rec_i]

# Keep required samples in current record
# Indices to keep, with 0 at the start of this record
start_keep = h['left'] - left_extension
end_keep = h['right'] + right_extension
can_cut[rec_i][max(0, start_keep):
min(end_keep, samples_per_record)] = 0

# Keep samples in previous/next record if applicable
# Never try to keep samples beyond the pulse
start_keep = max(
start_keep,
- samples_per_record * r['record_i'])
end_keep = min(
end_keep,
r['pulse_length'] - samples_per_record * r['record_i'])

# Indices of samples to keep in this record
a = max(0, start_keep)
b = min(end_keep, samples_per_record)
new_recs[rec_i]['data'][a:b] = records[rec_i]['data'][a:b]

# Keep samples in previous record, if there was one
if start_keep < 0:
prev_r = previous_record[rec_i]
if prev_r != NOT_APPLICABLE:
can_cut[prev_r][start_keep:] = 0
prev_ri = previous_record[rec_i]
if prev_ri != NO_RECORD_LINK:
# Note start_keep is negative, so this keeps the
# last few samples of the previous record
a_prev = start_keep
new_recs[prev_ri]['data'][a_prev:] = \
records[prev_ri]['data'][a_prev:]

# Same for the next record
if end_keep > samples_per_record:
next_r = next_record[rec_i]
if next_r != NOT_APPLICABLE:
can_cut[next_r][:end_keep - samples_per_record] = 0

# This is actually quite slow. Perhaps the [:] forces a copy?
# Without it, however, numba complains...
for i in range(len(can_cut)):
records[i]['data'][:] *= ~can_cut[i]
records['reduction_level'][:] = ReductionLevel.HITS_ONLY
next_ri = next_record[rec_i]
if next_ri != NO_RECORD_LINK:
b_next = end_keep - samples_per_record
new_recs[next_ri]['data'][:b_next] = \
records[next_ri]['data'][:b_next]
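With the array copy done in pure numpy and only the per-hit slicing in the numba kernel above, both layers stay fast. A hypothetical sketch of how these functions chain together (assuming the module's usual @export decorators; the records array would come from a DAQ reader):

import strax

strax.baseline(records)
strax.zero_out_of_bounds(records)
hits = strax.find_hits(records, threshold=15)
records = strax.cut_outside_hits(records, hits,
                                 left_extension=2, right_extension=15)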


@numba.jit(nopython=True, nogil=True, cache=True)
17 changes: 11 additions & 6 deletions strax/processing/peak_building.py
@@ -162,13 +162,17 @@ def sum_waveform(peaks, records, adc_to_pe, n_channels=248):
r_start = max(0, s)
r_end = min(n_r, s + n_p)
assert r_end > r_start

max_in_record = r['data'][r_start:r_end].max()
p['saturated_channel'][ch] = int(max_in_record >= r['baseline'])

# TODO Do we need .astype(np.int32).sum() ??
p['area_per_channel'][ch] += r['data'][r_start:r_end].sum()

bl_fpart = r['baseline'] % 1
p['area_per_channel'][ch] += (
r['data'][r_start:r_end].sum()
+ (int(round(
bl_fpart * (r_end - r_start)))))

# Range of peak that receives record
p_start = max(0, -s)
p_end = min(n_p, -s + n_r)
@@ -177,7 +181,7 @@ def sum_waveform(peaks, records, adc_to_pe, n_channels=248):

if p_end - p_start > 0:
swv_buffer[p_start:p_end] += \
r['data'][r_start:r_end] * adc_to_pe[ch]
(r['data'][r_start:r_end] + bl_fpart) * adc_to_pe[ch]

# Store the sum waveform
# Do we need to downsample the swv to store it?
@@ -192,6 +196,7 @@
p['data'][:p_length] = swv_buffer[:p_length]

# Store the total area and saturation count
p['area'] = (p['area_per_channel'][:n_channels] * adc_to_pe[:n_channels]).sum()
p['area'] = (p['area_per_channel'][:n_channels]
* adc_to_pe[:n_channels]).sum()
p['n_saturated_channels'] = p['saturated_channel'][:n_channels].sum()
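The bl_fpart terms above restore the fractional ADC count lost when the float baseline is subtracted as an integer. A worked example with illustrative numbers, assuming samples are stored sign-flipped as int(baseline) - raw, as strax's baseline() does:

baseline = 1600.7                                # float baseline estimate
raw = [1598, 1595, 1597]                         # raw ADC samples
stored = [int(baseline) - x for x in raw]        # [2, 5, 3]
exact = sum(baseline - x for x in raw)           # 12.1
corrected = sum(stored) + (baseline % 1) * len(raw)   # 10 + 0.7 * 3 = 12.1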

133 changes: 124 additions & 9 deletions strax/processing/pulse_processing.py
@@ -3,12 +3,14 @@
"""
import numpy as np
import numba
from scipy.ndimage import convolve1d

import strax
export, __all__ = strax.exporter()
__all__ += ['NO_RECORD_LINK']

# Constant for use in record_links, to indicate there is no prev/next record
NOT_APPLICABLE = -1
NO_RECORD_LINK = -1


@export
@@ -22,7 +24,7 @@ def baseline(records, baseline_samples=40):
baselining them!)
"""
if not len(records):
return
return records
samples_per_record = len(records[0]['data'])

# Array for looking up last baseline seen in channel
@@ -47,11 +49,35 @@
d.baseline = bl


@export
@numba.jit(nopython=True, nogil=True, cache=True)
def zero_out_of_bounds(records):
""""Set waveforms to zero out of pulse bounds
"""
if not len(records):
return records
samples_per_record = len(records[0]['data'])

for r in records:
end = r['pulse_length'] - r['record_i'] * samples_per_record
if end < samples_per_record:
r['data'][end:] = 0
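A quick numeric check of the bounds arithmetic (illustrative numbers):

samples_per_record = 110
pulse_length = 260     # pulse spans three records: 110 + 110 + 40
record_i = 2           # last record of the pulse
end = pulse_length - record_i * samples_per_record   # 40
# samples [40:110] of this record lie beyond the pulse and are zeroed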


@export
@numba.jit(nopython=True, nogil=True, cache=True)
def integrate(records):
if not len(records):
return
samples_per_record = len(records[0]['data'])
for i, r in enumerate(records):
records[i]['area'] = r['data'].sum()
n_real_samples = min(
samples_per_record,
r['pulse_length'] - r['record_i'] * samples_per_record)
records[i]['area'] = (
r['data'].sum()
+ int(round((r['baseline'] % 1) * n_real_samples)))



@export
@@ -65,11 +91,11 @@ def record_links(records):
return
n_channels = records['channel'].max() + 1
samples_per_record = len(records[0]['data'])
previous_record = np.ones(len(records), dtype=np.int32) * NOT_APPLICABLE
next_record = np.ones(len(records), dtype=np.int32) * NOT_APPLICABLE
previous_record = np.ones(len(records), dtype=np.int32) * NO_RECORD_LINK
next_record = np.ones(len(records), dtype=np.int32) * NO_RECORD_LINK

# What was the index of the last record seen in each channel?
last_record_seen = np.ones(n_channels, dtype=np.int32) * NOT_APPLICABLE
last_record_seen = np.ones(n_channels, dtype=np.int32) * NO_RECORD_LINK
# What would the start time be of a record that continues that record?
expected_next_start = np.zeros(n_channels, dtype=np.int64)

@@ -89,7 +115,7 @@

if r['record_i'] == 0:
# Record starts a new pulse
previous_record[i] = NOT_APPLICABLE
previous_record[i] = NO_RECORD_LINK

elif r['time'] == expected_next_start[ch]:
# Continuing record.
@@ -127,11 +153,12 @@ def find_hits(records, threshold=15, _result_buffer=None):
# print("Starting record ', record_i)
in_interval = False
hit_start = -1
area = 0

for i in range(samples_per_record):
for i, x in enumerate(r['data']):
# We can't use enumerate over r['data'], numba gives error
# TODO: file issue?
above_threshold = r['data'][i] > threshold
above_threshold = x > threshold
# print(r['data'][i], above_threshold, in_interval, hit_start)

if not in_interval and above_threshold:
@@ -149,8 +176,12 @@ def find_hits(records, threshold=15, _result_buffer=None):
# Hit ends at the *end* of this sample
# (because the record ends)
hit_end = i + 1
area += x
in_interval = False

else:
area += x

if not in_interval:
# print('saving hit')
# Hit is done, add it to the result
@@ -167,6 +198,10 @@ def find_hits(records, threshold=15, _result_buffer=None):
res['dt'] = r['dt']
res['channel'] = r['channel']
res['record_i'] = record_i
area += int(round(
res['length'] * (r['baseline'] % 1)))
res['area'] = area
area = 0

# Yield buffer to caller if needed
offset += 1
@@ -178,3 +213,83 @@
# hit_start = 0
# hit_end = 0
yield offset
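find_hits now accumulates each hit's area on the fly and applies the fractional-baseline correction once per hit. Assuming the function keeps strax's growing-result pattern (_result_buffer is injected by a decorator above this hunk), the caller simply passes records:

hits = strax.find_hits(records, threshold=15)
# hits['area'] includes the per-hit fractional-baseline correction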


def filter_records(r, ir):
"""Apply filter with impulse response ir over the records r.
Assumes the filter origin is at the impulse response maximum.
:param r: Array of records; the 'data' field is filtered in place.
:param ir: Impulse response, must have odd length. Will be normalized.
"""
# Convert waveforms to float and restore baseline
ws = r['data'].astype(np.float) + (r['baseline'] % 1)[:, np.newaxis]

prev_r, next_r = strax.record_links(r)
ws_filtered = filter_waveforms(
ws,
ir / ir.sum(),
prev_r, next_r)

# Restore waveforms as integers
r['data'] = ws_filtered.astype(np.int16)


@export
def filter_waveforms(ws, ir, prev_r, next_r):
"""Convolve filter with impulse response ir over each row of ws.
Assumes the filter origin is at the impulse response maximum.
:param ws: Waveform matrix, must be float
:param ir: Impulse response, must have odd length.
:param prev_r: Previous record map from strax.record_links
:param next_r: Next record map from strax.record_links
"""
n = len(ir)
a = n//2
if n % 2 == 0:
raise ValueError("Impulse response must have odd length")

# Do the convolutions outside numba;
# numba supports np.convolve, but this seems to be quite slow

# Main convolution
maxi = np.argmax(ir)
result = convolve1d(ws,
ir,
origin=maxi - a,
mode='constant')

# Contribution to next record (if present)
have_next = ws[next_r != -1]
to_next = convolve1d(have_next[:, -(n - maxi - 1):],
ir,
origin=a,
mode='constant')

# Contribution to previous record (if present)
have_prev = ws[prev_r != -1]
to_prev = convolve1d(have_prev[:, :maxi],
ir,
origin=-a,
mode='constant')

# Combine the results in numba; here numba is much faster (~100x?)
# than a numpy assignment using boolean array instead of a for loop.
_combine_filter_results(result, to_next, to_prev, next_r, prev_r, maxi, n)
return result
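The slices above encode how far a kernel of length n with origin at index maxi reaches outside a record: maxi samples into the previous record and n - maxi - 1 samples into the next. A sanity check with illustrative numbers:

n, maxi = 5, 2   # symmetric kernel, origin at the center
# first maxi = 2 samples of a record spill into the previous record,
# last n - maxi - 1 = 2 samples spill into the next record,
# matching the slices [:maxi] and [-(n - maxi - 1):] above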


@numba.jit(nopython=True, cache=True, nogil=True)
def _combine_filter_results(result, to_next, to_prev, next_r, prev_r, maxi, n):
seen_that_have_next = 0
seen_that_have_prev = 0
for i in range(len(result)):
if next_r[i] != NO_RECORD_LINK:
result[next_r[i], :n - maxi - 1] += to_next[seen_that_have_next]
seen_that_have_next += 1
if prev_r[i] != NO_RECORD_LINK:
result[prev_r[i], -maxi:] += to_prev[seen_that_have_prev]
seen_that_have_prev += 1
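A hypothetical usage sketch for the new filtering helpers, with a simple smoothing kernel (kernel values are illustrative; calling it as strax.filter_records assumes an @export not shown in this hunk):

import numpy as np

ir = np.array([1., 2., 4., 2., 1.])   # odd length, as required
strax.filter_records(records, ir)     # filters records['data'] in place,
                                      # spilling edges into neighboring records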
