Skip to content

Commit

Permalink
Fix argsort inside numba.jit using kind='mergesort' (#721)
Browse files Browse the repository at this point in the history
* Fix argsort inside numba.jit using kind='mergesort'

See https://numba.readthedocs.io/en/stable/reference/numpysupported.html#other-methods

* Provide `sort_kind` as an argument of _sort_by_time_and_channel
  • Loading branch information
dachengx committed May 4, 2023
1 parent ce42749 commit 73402a3
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions strax/processing/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def sort_by_time(x):
return x

@numba.jit(nopython=True, nogil=True, cache=True)
def _sort_by_time_and_channel(x, channel, max_channel_plus_one):
def _sort_by_time_and_channel(x, channel, max_channel_plus_one, sort_kind='mergesort'):
"""
Assumes you have no more than 10k channels, and records don't span
more than 11 days.
Expand All @@ -49,7 +49,7 @@ def _sort_by_time_and_channel(x, channel, max_channel_plus_one):
# I couldn't get fast argsort on multiple keys to work in numba
# So, let's make a single key...
sort_key = (x['time'] - x['time'].min()) * max_channel_plus_one + channel
sort_i = np.argsort(sort_key)
sort_i = np.argsort(sort_key, kind=sort_kind)
return x[sort_i]


Expand Down
2 changes: 1 addition & 1 deletion strax/processing/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def highest_density_region(data, fractions_desired, only_upper_part=False, _buff
'with a total probability of less-equal 0.')

# Need an index which sorts the data by amplitude
max_to_min = np.argsort(data)[::-1]
max_to_min = np.argsort(data, kind='mergesort')[::-1]

lowest_sample_seen = np.inf
for j in range(1, len(data)):
Expand Down
2 changes: 1 addition & 1 deletion strax/run_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def define_run(self: strax.Context,
if isinstance(data, (pd.DataFrame, np.ndarray)):
if isinstance(data, np.ndarray):
data = pd.DataFrame.from_records(data)

# strax.endtime does not work with DataFrames due to numba
if 'endtime' in data.columns:
end = data['endtime']
Expand Down
2 changes: 1 addition & 1 deletion strax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def multi_run(exec_function, run_ids, *args,
if throw_away_result:
continue
result = f.result()

# Append the run id column
if add_run_id_field:
ids = np.array([_run_id] * len(result),
Expand Down

0 comments on commit 73402a3

Please sign in to comment.