Skip to content

Further additions to SI PR 4502#29

Draft
grahamfindlay wants to merge 10 commits into
alejoe91:get-unit-spike-trainsfrom
grahamfindlay:dev-pr4502
Draft

Further additions to SI PR 4502#29
grahamfindlay wants to merge 10 commits into
alejoe91:get-unit-spike-trainsfrom
grahamfindlay:dev-pr4502

Conversation

@grahamfindlay
Copy link
Copy Markdown

@grahamfindlay grahamfindlay commented Apr 17, 2026

  • Improve and add tests for sorting_tools.is_spike_vector_sorted()
  • Leverage fact that Phy/KS extractors are always single-segment to improve ._compute_and_cache_spike_vector().
  • Leverage when UnitSelectionSortings are single-segment to improve _compute_and_cache_spike_vector().

…improve _compute_and_cache_spike_vector()

Gains come from:

- Avoiding an unnecessary `np.concatenate()`
- Dropping `segment_index` from lexsort keys.

Another gain, not related to the single-segment thing, is that `np.empty()` can replace `np.zeros()`.
…ume_single_segment` shortcut.

Just checking whether the spike vector is sorted was allocating quite a bit of memory (for the diff arrays; ~3.25GB for 100M spikes). This chunked version only requires a constant amount of memory (~40MB).

This version also adds early stopping, of sorts. That was accidental consequence of the chunking tbh, but it makes sense that it beats diff'ing every element in the spike vector (the 1M spike benchmark below is a good indicator).

We also don't have to waste time/space checking the segment index if we know it is a single-segment vector.

Depending on how pathological the spike vector is, and whether or not it is single-segment, the speedup is ~1.5x-2.0x for 50M+ spikes, and a bit larger (~1.8-2.5x) for 1M spikes.
…ing._compute_and_cache_spike_vector()`.

We can easily check whether the USS is single-segement, and improve both the sortedness check and the lexsorting if so.
Two main improvements:

1. The cluster_id -> unit_index mapping now uses
   a dense `cluster_to_unit` lookup table: O(N).
   Replaces `np.searchsorted`, which was O(N log M).
   On 392M spikes and 342 units, this is 27s -> 3s.
2. A numba kernel builds the spike vector in one
   pass. This is possible because (a) the input arrays
   are already sorted by sample_index, and the only
   remaining tie-breaking is by unit_index within
   each sample_index run, and (b) the runs are
   short (single-digit number of spikes ) and rare
   (single-digit percentage of total spikes).
   On 392M spikes and 342 units, ~170s -> ~6-10s.
This fixes a performance regression (relative to
my prototype, not relative to a previous commit)
in Phy/Kilosort's `get_unit_spike_trains()`.

On 392M spikes and 342 clusters, this takes the
numba implementation from ~35s down to ~5s, and
the numpy implementation from ~110s down to ~80s.
@grahamfindlay
Copy link
Copy Markdown
Author

Note to self (mostly): The optimization used for Phy/Kilosort get_unit_spike_trains() is probably also immediately applicable to Tridesclous, ALF, HerdingSpikes, MDA, and YASS extractors. We should probably promote the LUT + counting sort method (both numba and numpy fallback implementations) to sorting_tools.py and call it in those other extractors too -- maybe save that for a separate PR though.

Use Samuel's trick from PR 4579:
`__init__()` doesn't waste time trying to remove
spikes and unit ids from "bad clusters" from the
full flat (`.npy``) arrays on load (`read_phy()``),
if there aren't any bad clusters to begin with.
This uses essentially the same dense LUT +
flat-view buffer trick as the Phy/Kilosort
`_compute_and_cache_spike_vector` method, in order
to allow remapping and filtering the parent's
spike vector in a single pass, without needing to
create any intermediate allocations.

On 392M spikes, selecting 258 out of 342 units in
a single pass, starting from a cached parent
spike vector, this reduces the additional time
needed to get the USS spike vector to ~5.5s.
…spike_vector()

Check if the user requested an "identity selection": all parent units, in
parent order, possibly renamed (the spike vector uses unit _index_, and
renaming doesn't affect that). If so, the cached parent spike vector is
identical to the one we want, so just share the reference and skip the rest.
`to_spike_vector()` already returns a (canonical)
vector whose sample_index is ascending within each
segment. The two supported lexsort keys
only differ from the canonical order by how they
grouping spikes into (unit, segment) buckets.
Sample order within each bucket is already correct
in the canonical vector. No re-sorting necessary!
Perfect for a (linear time!) counting sort.

A numba kernel based on textbook implementation
was fast enough that I didn't try alternatives.

Plus, the same array striding tricks used to
avoid intermediate copies are reused from recent
optimization commits, before passing to the
numba kernel.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant