Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Som plugin #1269

Merged
merged 28 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
99d9165
som peaklet classification
LuisSanchez25 Jul 31, 2023
9bf35fb
Updating plugin, this needs straxen PR Waveform attributes #745
Aug 4, 2023
280bdc4
update plugin, this works now
Aug 7, 2023
4949ef6
remove som type and propagate som classification up to peaks
Aug 10, 2023
57bd953
do not classify type 0
LuisSanchez25 Aug 14, 2023
ed8cfd5
Fix issue with array sisize
LuisSanchez25 Aug 14, 2023
303e0c7
add som_type
LuisSanchez25 Aug 16, 2023
b8c80c7
Update peaklet_classification_som.py
LuisSanchez25 Aug 25, 2023
a19ff24
fix som_type issue
LuisSanchez25 Aug 28, 2023
78afb2e
update model + some function details
LuisSanchez25 Aug 30, 2023
202ec92
new file + new type 3 added
LuisSanchez25 Sep 8, 2023
a8dec9b
Merge branch 'master' into SOM_plugin
LuisSanchez25 Sep 15, 2023
e650cbb
Merge branch 'master' into SOM_plugin
LuisSanchez25 Sep 16, 2023
2ad4902
separate som_plugin from main pipeline
LuisSanchez25 Sep 24, 2023
019a332
Merge branch 'master' into SOM_plugin
LuisSanchez25 Sep 24, 2023
2802303
Add tests to produce input vectors
LuisSanchez25 Sep 24, 2023
e59228a
Merge branch 'SOM_plugin' of https://github.com/XENONnT/straxen into …
LuisSanchez25 Sep 25, 2023
ca7fccc
current status of plugin, detach plugin before we are ready fro review
LuisSanchez25 Sep 29, 2023
1ff863c
start making new context
LuisSanchez25 Oct 3, 2023
4c0aaab
fix plugin and context
LuisSanchez25 Oct 4, 2023
e563201
Merge branch 'master' into SOM_plugin
LuisSanchez25 Oct 4, 2023
162dae4
clean up plugin
LuisSanchez25 Oct 4, 2023
4dbf283
Merge branch 'master' into SOM_plugin
LuisSanchez25 Oct 4, 2023
2c8618f
Update peaklet_classification_som.py
LuisSanchez25 Oct 4, 2023
2c67dd6
conform to pep8 style
LuisSanchez25 Oct 6, 2023
9607781
Merge branch 'master' into SOM_plugin
LuisSanchez25 Oct 6, 2023
66e392a
Update peaklet_classification_som.py
LuisSanchez25 Oct 6, 2023
d5cf03f
Update peaklet_classification_som.py
LuisSanchez25 Oct 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion straxen/plugins/merged_s2s/merged_s2s.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class MergedS2s(strax.OverlapWindowPlugin):
"""
__version__ = '1.0.2'

depends_on = ('peaklets', 'peaklet_classification', 'lone_hits')
depends_on = ('peaklets', 'peaklet_classification_som', 'lone_hits')
jmosbacher marked this conversation as resolved.
Show resolved Hide resolved
data_kind = 'merged_s2s'
provides = 'merged_s2s'

Expand Down
3 changes: 3 additions & 0 deletions straxen/plugins/peaklets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

from . import peaklet_classification
from .peaklet_classification import *

from . import peaklet_classification_som
jmosbacher marked this conversation as resolved.
Show resolved Hide resolved
from .peaklet_classification_som import *
jmosbacher marked this conversation as resolved.
Show resolved Hide resolved
jmosbacher marked this conversation as resolved.
Show resolved Hide resolved
jmosbacher marked this conversation as resolved.
Show resolved Hide resolved
254 changes: 254 additions & 0 deletions straxen/plugins/peaklets/peaklet_classification_som.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
import numpy as np
import numpy.lib.recfunctions as rfn
LuisSanchez25 marked this conversation as resolved.
Show resolved Hide resolved
from scipy.spatial.distance import cdist
import numba

import strax
import straxen

export, __all__ = strax.exporter()


@export
class PeakletClassificationSOM(strax.Plugin):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
WPS230 Found too many public instance attributes: 7 > 6

"""

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
D204 1 blank line required after class docstring

Self-Organizing Maps (SOM)
https://xe1t-wiki.lngs.infn.it/doku.php?id=xenon:xenonnt:lsanchez:unsupervised_neural_network_som_methods
For peaklet classification. This plugin will provide 2 data types, the 'type' we are
already familiar with, classifying peaklets as s1, s2 (using the new classification) or
unknown (from the previous classification). As well as a new data type, SOM type, which
will be assigned numbers based on the cluster in the SOM in which they are found. For
each version I will make some documentation in the corrections repository explaining
what I believe each cluster represents.

This correction/plugin is currently in the testing phase; feel free to use it if you are
curious or just want to try it out, but note that it is not ready to be used in
analysis.
"""
__version__ = '0.0.1'

#rechunk_on_save = immutabledict(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
E800 Found commented out code

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
E265 block comment should start with '# '

# peaklet_classification_som=True,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
E800 Found commented out code

# som_peaklet_data=True)

depends_on = ('peaklets', 'peaklet_classification')

provides = ('peaklet_classification', 'som_peaklet_data')
#provides = ('peaklet_classification_som')
data_kind = {k: k for k in provides}
jmosbacher marked this conversation as resolved.
Show resolved Hide resolved

# parallel = True

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
E800 Found commented out code


som_files = straxen.URLConfig(default='resource://xedocs://som_classifiers?attr=value&version=v1&run_id=045000&fmt=npy')

#dtype = (strax.peak_interval_dtype
# + [('type', np.int8, 'Classification of the peak(let)')])


def infer_dtype(self):
LuisSanchez25 marked this conversation as resolved.
Show resolved Hide resolved
dtype = dict()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
C408 Unnecessary dict call - rewrite as a literal.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dtype = dict()
dtype = {}

dtype['peaklet_classification_som'] = (strax.peak_interval_dtype +

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
W504 line break after binary operator

[('type', np.int8, 'Classification of the peak(let)')])
dtype['som_peaklet_data'] = (strax.peak_interval_dtype + [('som_type', np.int8, 'SOM type of the peak(let)')]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
WPS221 Found line with high Jones Complexity: 15 > 14

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
E501 line too long (117 > 100 characters)

+ [('loc_x_som', np.int16, 'x location of the peak(let) in the SOM')]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
W503 line break before binary operator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
E501 line too long (106 > 100 characters)

+ [('loc_y_som', np.int16, 'y location of the peak(let) in the SOM')])

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
W503 line break before binary operator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
E501 line too long (107 > 100 characters)


return dtype


def setup(self):
LuisSanchez25 marked this conversation as resolved.
Show resolved Hide resolved
self.som_weight_cube = self.som_files['weight_cube']
self.som_img = self.som_files['som_img']
self.som_norm_factors = self.som_files['norm_factors']
self.som_s1_array = self.som_files['s1_array']
self.som_s2_array = self.som_files['s2_array']
self.som_s3_array = self.som_files['s3_array']
self.som_s0_array = self.som_files['s0_array']

def compute(self, peaklets):
LuisSanchez25 marked this conversation as resolved.
Show resolved Hide resolved
peaklets_w_type = peaklets.copy()
mask_non_zero = peaklets_w_type['type'] != 0
peaklets_w_type = peaklets_w_type[mask_non_zero]
#result = np.zeros(len(peaklets), dtype=self.dtype)
result = np.zeros(len(peaklets), dtype=self.dtype['peaklet_classification_som'])
som_info = np.zeros(len(peaklets), dtype=self.dtype['som_peaklet_data'])
som_type, x_som, y_som = recall_populations(peaklets_w_type, self.som_weight_cube,
self.som_img,
self.som_norm_factors)

som_info['time'] = peaklets['time']
som_info['length'] = peaklets['length']
som_info['dt'] = peaklets['dt']
som_info['som_type'][mask_non_zero] = som_type
som_info['loc_x_som'][mask_non_zero] = x_som
som_info['loc_y_som'][mask_non_zero] = y_som

LuisSanchez25 marked this conversation as resolved.
Show resolved Hide resolved
strax_type = som_type_to_type(som_type,
self.som_s1_array,
self.som_s2_array,
self.som_s3_array,
self.som_s0_array)
result['time'] = peaklets['time']
result['length'] = peaklets['length']
result['dt'] = peaklets['dt']
result['type'][mask_non_zero] = strax_type
#result['som_type'][mask_non_zero] = som_type + 1

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
E800 Found commented out code

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
E265 block comment should start with '# '

return dict(peaklet_classification_som=result, som_peaklet_data=som_info)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
C408 Unnecessary dict call - rewrite as a literal.

#return result


def recall_populations(dataset, weight_cube, SOM_cls_img, norm_factors):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
N803 argument name 'SOM_cls_img' should be lowercase

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please fix this

"""
Master function that lets the user provide a weight cube,
a reference img as a np.array, a dataset and a set of normalization factors.
If these 4 things are provided, this function returns the SOM type of every
entry in the dataset, together with the x and y location of its best-matching
neuron in the SOM.
weight_cube: SOM weight cube (3D array)
SOM_cls_img: SOM reference image as a numpy array
dataset: Data to perform the recall on (should be peaklet-level data)
norm_factors: A set of numbers to normalize the data so we can perform a recall
(indices 0 to 11 are read when the SOM inputs are built)
"""
[SOM_xdim, SOM_ydim, SOM_zdim] = weight_cube.shape

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
WPS359 Found an iterable unpacking to list

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
[SOM_xdim, SOM_ydim, SOM_zdim] = weight_cube.shape
xdim, ydim, zdim = weight_cube.shape

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pep8] reported by reviewdog 🐶
N806 variable 'SOM_xdim' in function should be lowercase

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove all these CamelCase variable names. Use snake_case for variable names.

[IMG_xdim, IMG_ydim, IMG_zdim] = SOM_cls_img.shape
unique_colors = np.unique(np.reshape(SOM_cls_img, [SOM_xdim * SOM_ydim, 3]), axis=0)
# Checks that the reference image matches the weight cube
assert SOM_xdim == IMG_xdim, f'Dimensions mismatch between SOM weight cube ({SOM_xdim}) and reference image ({IMG_xdim})'
assert SOM_ydim == IMG_ydim, f'Dimensions mismatch between SOM weight cube ({SOM_ydim}) and reference image ({IMG_ydim})'

assert all(dataset['type'] != 0), 'Dataset contains unclassified peaklets'
# Get the deciles representation of data for recall
decile_transform_check = data_to_log_decile_log_area_aft(dataset, norm_factors)
# preform a recall of the dataset with the weight cube
# assign each population color a number (can do from previous function)
ref_map = generate_color_ref_map(SOM_cls_img, unique_colors, SOM_xdim, SOM_ydim)
SOM_cls_array = np.empty(len(dataset['area']))
SOM_cls_array[:] = np.nan
# Make new numpy structured array to save the SOM cls data
data_with_SOM_cls = rfn.append_fields(dataset, 'SOM_type', SOM_cls_array)
# preforms the recall and assigns SOM_type label
output_data, x_som, y_som = SOM_cls_recall(data_with_SOM_cls, decile_transform_check, weight_cube, ref_map)
return output_data['SOM_type'], x_som, y_som


def generate_color_ref_map(color_image, unique_colors, xdim, ydim):
    """
    Build a 2D map labelling every pixel with the index of its color.

    :param color_image: reference image indexed as (x, y, channel); every pixel
        is compared channel-wise against each entry of ``unique_colors``
    :param unique_colors: sequence of colors (one per row); the position of a
        color in this sequence is the label written into the map
    :param xdim: size of the first axis of the returned map
    :param ydim: size of the second axis of the returned map
    :return: float array of shape (xdim, ydim). A pixel equal to
        ``unique_colors[i]`` holds ``i``; pixels matching no color stay 0.
    """
    ref_map = np.zeros((xdim, ydim))
    for label, color in enumerate(unique_colors):
        # True where every channel of the pixel equals this color
        mask = np.all(np.equal(color_image, color), axis=2)
        # Assign all matching pixels at once instead of looping over them in Python
        rows, cols = np.nonzero(mask)
        ref_map[rows, cols] = label
    return ref_map


def SOM_cls_recall(array_to_fill, data_in_SOM_fmt, weight_cube, reference_map):
    """
    Label every data point with the class of its closest SOM neuron.

    :param array_to_fill: structured array with a 'SOM_type' field; the field is
        overwritten in place and the same array is returned
    :param data_in_SOM_fmt: 2D array, one row of SOM input features per data point
    :param weight_cube: 3D SOM weight cube indexed as (x, y, feature)
    :param reference_map: 2D map holding the class label of each (x, y) neuron
    :return: tuple (array_to_fill, x_idx, y_idx) where x_idx / y_idx are the map
        coordinates of the best-matching neuron of every data point
    """
    n_x, n_y, n_features = weight_cube.shape
    # One row per neuron, so row k corresponds to map position unravel_index(k, (n_x, n_y))
    neurons = weight_cube.reshape(-1, n_features)
    # Euclidean distance from every neuron (rows) to every data point (columns)
    neuron_to_point = cdist(neurons, data_in_SOM_fmt, metric='euclidean')
    best_neuron = np.argmin(neuron_to_point, axis=0)
    x_idx, y_idx = np.unravel_index(best_neuron, (n_x, n_y))
    array_to_fill['SOM_type'] = reference_map[x_idx, y_idx]
    return array_to_fill, x_idx, y_idx


def som_type_to_type(som_type, s1_array, s2_array, s3_array, s0_array):
    """
    Convert SOM types (cluster numbers) into strax peak types.

    Membership is evaluated on the untouched input, so a freshly written strax
    type can never be mistaken for a SOM type. This removes the need for large
    temporary sentinel values, which do not fit in small integer dtypes such as
    int8.

    :param som_type: array with integers corresponding to the different SOM types
    :param s1_array: SOM types which should be converted to S1 (type 1)
    :param s2_array: SOM types which should be converted to S2 (type 2)
    :param s3_array: SOM types which should be converted to type 3
    :param s0_array: SOM types which should be converted to type 0
    :return: new array with the same shape and dtype as ``som_type``. Entries
        listed in none of the arrays keep their SOM type value; the input is
        not modified.
    """
    som_type = np.asarray(som_type)
    converted = som_type.copy()
    # Applied from lowest to highest priority: if a SOM type is listed in more
    # than one array, the last assignment wins (s1 > s2 > s3 > s0).
    remap = ((s0_array, 0), (s3_array, 3), (s2_array, 2), (s1_array, 1))
    for members, strax_type in remap:
        converted[np.isin(som_type, members)] = strax_type
    return converted


# Need function to convert things to S1s and S2s
def data_to_log_decile_log_area_aft(peaklet_data, normalization_factor):
    """
    Turn peaklet data into the SOM input vectors:
    log10(deciles) + log10(area) + area fraction top (AFT).

    Since we are dealing with logs, deciles below 1 are set to 1.

    :param peaklet_data: structured array providing the 'area' and
        'area_per_channel' fields, plus whatever compute_quantiles reads
    :param normalization_factor: indexable of normalization constants;
        entries 0-9 scale the log-deciles, entry 10 scales the log-area and
        entry 11 is added (together with 1) to the area before the logarithm
    :return: 2D array with one row per peaklet and 12 columns:
        10 normalized log-deciles, normalized log-area, AFT clipped to [0, 1]
    """
    # Deciles: floor at 1 before the log, then scale each decile by its own factor
    deciles = compute_quantiles(peaklet_data, 10)
    deciles[deciles < 1] = 1
    log_deciles = np.log10(deciles) / normalization_factor[:10]

    # Area: shift before the log; the cast keeps the precision of the 'area' field
    raw_area = peaklet_data['area']
    shifted_area = (raw_area + normalization_factor[11] + 1).astype(raw_area.dtype)
    log_area = np.log10(shifted_area)

    # AFT from the unshifted area; the two where() calls clip to [0, 1] and map NaN to 0
    top_area = np.sum(peaklet_data['area_per_channel'][:, :straxen.n_top_pmts], axis=1)
    aft = top_area / raw_area
    aft = np.where(aft > 0, aft, 0)
    aft = np.where(aft < 1, aft, 1)

    n_peaklets = len(log_area)
    columns = (
        log_deciles,
        log_area.reshape(n_peaklets, 1) / normalization_factor[10],
        aft.reshape(n_peaklets, 1),
    )
    return np.concatenate(columns, axis=1)


def compute_quantiles(peaks: np.ndarray, n_samples: int):
    """
    Compute the area quantiles of the waveforms stored in peaks.

    Negative samples are set to zero on a copy of the waveforms before the
    quantiles are computed; the input array is left untouched.

    :param peaks: structured array with 'data' (waveforms) and 'dt' fields
    :param n_samples: number of nodes or attributes (quantiles per waveform)
    :return: quantiles, as returned by compute_wf_attributes
    """
    waveforms = peaks['data'].copy()
    negative = waveforms < 0.0
    waveforms[negative] = 0.0
    return compute_wf_attributes(waveforms, peaks['dt'], n_samples)


@export
@numba.jit(nopython=True, cache=True)
def compute_wf_attributes(data, sample_length, n_samples: int):
    """
    Compute waveform attributes (area quantiles) for a batch of waveforms.

    Quantiles: represent the amount of time elapsed for
    a given fraction of the total waveform area to be observed in n_samples
    i.e. n_samples = 10, then quantiles are equivalent deciles

    :param data: 2D array of waveforms, one waveform per row,
        e.g. the 'data' field of peaks or peaklets
    :param sample_length: duration of a single sample (dt) of each waveform,
        one entry per row of data
    :param n_samples: compute quantiles for a given number of samples
    :return: float64 array of shape (len(data), n_samples); entry [i, k] is the
        time waveform i takes to go from k / n_samples to (k + 1) / n_samples
        of its total area. Rows whose samples sum to zero are left as zeros.
    :raises AssertionError: if data and sample_length differ in length, if the
        waveform has not more than n_samples samples, or if the number of
        waveform samples is not a multiple of n_samples
    """
    assert data.shape[0] == len(sample_length), "ararys must have same size"

    num_samples = data.shape[1]

    quantiles = np.zeros((len(data), n_samples), dtype=np.float64)

    # Cannot compute with more samples than the actual waveform has
    assert num_samples > n_samples, "cannot compute with more samples than the actual waveform"
    # NOTE(review): nothing below divides the waveform into n_samples bins, so this
    # multiple-of requirement looks inherited from elsewhere - confirm it is still needed
    assert num_samples % n_samples == 0, "number of samples must be a multiple of n_samples"

    # Compute quantiles
    # Area fractions at which a quantile starts: 0, 1/n, ..., (n - 1)/n
    inter_points = np.linspace(0., 1. - (1. / n_samples), n_samples)
    # Buffers reused for every waveform
    cumsum_steps = np.zeros(n_samples + 1, dtype=np.float64)
    frac_of_cumsum = np.zeros(num_samples + 1)
    # Sample edges 0 .. num_samples; multiplied by dt below to get edge times
    sample_number_div_dt = np.arange(0, num_samples + 1, 1)
    for i, (samples, dt) in enumerate(zip(data, sample_length)):
        # A waveform without area has no defined quantiles: keep its row at zero
        if np.sum(samples) == 0:
            continue
        # reset buffers
        frac_of_cumsum[:] = 0
        cumsum_steps[:] = 0
        # Normalized cumulative area at every sample edge (0 at the first edge)
        frac_of_cumsum[1:] = np.cumsum(samples)
        frac_of_cumsum[1:] = frac_of_cumsum[1:] / frac_of_cumsum[-1]
        # Time at which each area fraction is reached
        # (assumes samples are non-negative so the fractions are non-decreasing - TODO confirm for all callers)
        cumsum_steps[:-1] = np.interp(inter_points, frac_of_cumsum, sample_number_div_dt * dt)
        # The last quantile ends at the end of the waveform
        cumsum_steps[-1] = sample_number_div_dt[-1] * dt
        # Duration of each quantile = difference between consecutive fraction times
        quantiles[i] = cumsum_steps[1:] - cumsum_steps[:-1]

    return quantiles
2 changes: 1 addition & 1 deletion straxen/plugins/peaks/peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Peaks(strax.Plugin):
"""
__version__ = '0.1.2'

depends_on = ('peaklets', 'peaklet_classification', 'merged_s2s')
depends_on = ('peaklets', 'peaklet_classification_som', 'merged_s2s')
data_kind = 'peaks'
provides = 'peaks'
parallel = True
Expand Down
61 changes: 60 additions & 1 deletion tests/test_peaklet_processing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,43 @@
import numpy as np
from hypothesis import given, settings
from hypothesis import given, strategies, example, settings
import hypothesis.strategies as strat
import strax
from strax.testutils import fake_hits
import straxen
from straxen.plugins.peaklets.peaklets import get_tight_coin
from straxen.plugins.peaklets.peaklet_classification_som import compute_wf_attributes


def get_filled_peaks(peak_length, data_length, n_widths):
    """
    Create an array of peaks filled with random waveforms for testing.

    :param peak_length: number of peaks to create
    :param data_length: number of samples in the 'data' field of every peak
    :param n_widths: size of the 'width' and 'area_decile_from_midpoint'
        fields; when None these two fields are left out of the dtype
    :return: structured array with 'time', 'dt', 'length', 'area' and 'data'
        (plus the optional width fields). Every peak gets a random length in
        [0, data_length) and that many uniform random samples; 'area' is the
        sum of the samples.
    """
    fields = [
        (('Start time since unix epoch [ns]', 'time'), np.int64),
        (('dt in ns', 'dt'), np.int64),
        (('length of p', 'length'), np.int16),
        (('area of p', 'area'), np.float64),
        (('data of p', 'data'), (np.float64, data_length)),
    ]
    if n_widths is not None:
        fields.extend([
            (('width of p', 'width'),
             (np.float64, n_widths)),
            (('area_decile_from_midpoint of p', 'area_decile_from_midpoint'),
             (np.float64, n_widths)),
        ])

    peaks = np.zeros(peak_length, dtype=fields)
    dt = 1
    peaks['time'] = np.arange(peak_length) * dt
    peaks['dt'] = dt

    # Fill the peaks with random length data
    for idx in range(len(peaks)):
        n_filled = np.random.randint(0, data_length)
        peaks['length'][idx] = n_filled
        peaks['data'][idx, :n_filled] = np.random.random(size=n_filled)

    if len(peaks):
        # Compute sum area
        peaks['area'] = np.sum(peaks['data'], axis=1)
    return peaks

@settings(deadline=None)
@given(strat.lists(strat.integers(min_value=0, max_value=10),
min_size=8, max_size=8, unique=True),
Expand Down Expand Up @@ -77,3 +108,31 @@ def test_get_tight_coin(hits, channel):
m_hits_in_peak &= (hits_max_time <= (p_max_t + right))
n_channel = len(np.unique(hits[m_hits_in_peak]['channel']))
assert n_channel == tight_coin_channel[ind], f'Wrong number of tight channel got {tight_coin_channel[ind]}, but expectd {n_channel}' # noqa


@settings(max_examples=100, deadline=None)
@given(
    # number of peaks
    strategies.integers(min_value=0, max_value=20),
    # length of the data field in the peaks
    strategies.integers(min_value=2, max_value=20),
    # Number of widths to compute
    strategies.integers(min_value=2, max_value=10),
)
def test_compute_wf_attributes(peak_length, data_length, n_widths):
    """
    Test compute_wf_attributes from the peaklet SOM classification plugin.

    Randomly filled peaks are passed to compute_wf_attributes. Inputs that the
    function rejects through one of its documented precondition assertions are
    skipped; any other AssertionError is re-raised so that it fails the test.
    For accepted inputs the quantiles must have one row per peak and must not
    contain NaN values.
    """
    n_samples = 10
    peaks = get_filled_peaks(peak_length, data_length, n_widths)

    # Messages of the precondition failures this test tolerates
    tolerated_messages = (
        "zero waveform",
        "more samples than the actual waveform",
        "multiple of n_samples",
    )
    try:
        q = compute_wf_attributes(peaks['data'], peaks['dt'], n_samples)
    except AssertionError as e:
        if any(message in str(e) for message in tolerated_messages):
            # The generated waveform shape is not supported: nothing to check
            return
        raise

    assert q.shape == (len(peaks), n_samples), "unexpected shape of the quantiles"
    assert np.all(~np.isnan(q)), "attributes contains NaN values"