From 73402a3647d4066decdb210716985d77aead268e Mon Sep 17 00:00:00 2001 From: Dacheng Xu Date: Thu, 4 May 2023 11:28:02 -0400 Subject: [PATCH] Fix argsort inside numba.jit using kind='mergesort' (#721) * Fix argsort inside numba.jit using kind='mergesort' See https://numba.readthedocs.io/en/stable/reference/numpysupported.html#other-methods * Provide `sort_kind` as an argument of _sort_by_time_and_channel --- strax/processing/general.py | 4 ++-- strax/processing/statistics.py | 2 +- strax/run_selection.py | 2 +- strax/utils.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/strax/processing/general.py b/strax/processing/general.py index 483c8cbf6..d46ebe267 100644 --- a/strax/processing/general.py +++ b/strax/processing/general.py @@ -39,7 +39,7 @@ def sort_by_time(x): return x @numba.jit(nopython=True, nogil=True, cache=True) -def _sort_by_time_and_channel(x, channel, max_channel_plus_one): +def _sort_by_time_and_channel(x, channel, max_channel_plus_one, sort_kind='mergesort'): """ Assumes you have no more than 10k channels, and records don't span more than 11 days. @@ -49,7 +49,7 @@ def _sort_by_time_and_channel(x, channel, max_channel_plus_one): # I couldn't get fast argsort on multiple keys to work in numba # So, let's make a single key... sort_key = (x['time'] - x['time'].min()) * max_channel_plus_one + channel - sort_i = np.argsort(sort_key) + sort_i = np.argsort(sort_key, kind=sort_kind) return x[sort_i] diff --git a/strax/processing/statistics.py b/strax/processing/statistics.py index 3ed848094..98b37c2bd 100644 --- a/strax/processing/statistics.py +++ b/strax/processing/statistics.py @@ -40,7 +40,7 @@ def highest_density_region(data, fractions_desired, only_upper_part=False, _buff 'with a total probability of less-equal 0.') # Need an index which sorted by amplitude - max_to_min = np.argsort(data)[::-1] + max_to_min = np.argsort(data, kind='mergesort')[::-1] lowest_sample_seen = np.inf for j in range(1, len(data)): diff --git a/strax/run_selection.py b/strax/run_selection.py index 7982c1dd8..9d0b30f4e 100644 --- a/strax/run_selection.py +++ b/strax/run_selection.py @@ -306,7 +306,7 @@ def define_run(self: strax.Context, if isinstance(data, (pd.DataFrame, np.ndarray)): if isinstance(data, np.ndarray): data = pd.DataFrame.from_records(data) - + # strax.endtime does not work with DataFrames due to numba if 'endtime' in data.columns: end = data['endtime'] diff --git a/strax/utils.py b/strax/utils.py index 912061ce0..20f86a88a 100644 --- a/strax/utils.py +++ b/strax/utils.py @@ -543,7 +543,7 @@ def multi_run(exec_function, run_ids, *args, if throw_away_result: continue result = f.result() - + # Append the run id column if add_run_id_field: ids = np.array([_run_id] * len(result),