n_workers kwarg for FilterRecording and CommonReferenceRecording #4564

Open
galenlynch wants to merge 1 commit into SpikeInterface:main from galenlynch:perf/parallel-filter-cmr

Conversation


@galenlynch galenlynch commented Apr 24, 2026

Split out from draft PR #4562. Companion PR #4563 (FIR phase-shift + apply_raised_cosine_taper extract). No code dependency; either can land first.

Summary

Adds opt-in intra-chunk thread-parallelism to two preprocessors:

  • FilterRecording(n_workers=N) — channel-split sosfilt/sosfiltfilt
  • CommonReferenceRecording(n_workers=N) — time-split median/mean

Default n_workers=1 preserves existing behavior bit-for-bit. Each outer thread that calls get_traces() on a parallel-enabled segment gets its own inner ThreadPoolExecutor (per-caller-thread pool semantics), so the kwarg composes cleanly with TimeSeriesChunkExecutor outer parallelism — no shared-pool queueing pathology.

Throughout this PR text, CRE refers to SI's TimeSeriesChunkExecutor (formerly ChunkRecordingExecutor) — the chunk-parallel worker pool invoked internally by write_binary_recording, most sorters, run_node_pipeline, and related batch utilities.

Headline. 1M × 384 float32, 24-core x86_64 host:

Per-stage (this PR alone):

| Stage | Kwarg | Stock | Parallel | Speedup |
| --- | --- | --- | --- | --- |
| Bandpass 300–6000 Hz | `FilterRecording(n_workers=8)` | 8.59 s | 3.20 s | 2.69× |
| CMR median (global) | `CommonReferenceRecording(n_workers=16)` | 4.01 s | 0.81 s | 4.95× |

Full pipeline (combined with the companion FIR phase-shift PR; the stock phase-shift FFT dominates the pipeline, so this PR alone improves full-pipeline wall-clock by only ~1.1× — the per-stage gains are dwarfed by the unchanged phase-shift cost):

| Pipeline config | Stock | FIR alone (companion PR) | Combined (both PRs) |
| --- | --- | --- | --- |
| Direct `get_traces()`, int16 | 82.1 s | 15.5 s (5.4×) | 6.28 s (13.1×) |
| Direct `get_traces()`, f32 propagated | 85.8 s | 19.6 s (4.5×) | 4.40 s (19.5×) |

Combined numbers require both PRs merged.

Motivation

SI's existing TimeSeriesChunkExecutor (formerly ChunkRecordingExecutor) uses outer chunk-parallelism: each worker pulls a time slice, processes it serially. This is efficient for batch workflows (write_binary_recording, sorters, node pipelines) that own the control flow end-to-end.

It doesn't serve direct rec.get_traces(start, end) callers: interactive viewers, streaming consumers, custom loops with their own prefetch scheduling. Those callers can't reach CRE's parallelism without adopting its batch-processing API. This PR adds a second axis — intra-chunk thread parallelism, applied within a single get_traces() call — that any caller can opt into by passing a kwarg.

Secondary motivation: default SI users (n_jobs=1 is the default) get parallelism without needing to configure multi-process job_kwargs.

Changes

1. FilterRecording(n_workers=N) — channel-parallel SOS

File: src/spikeinterface/preprocessing/filter.py

  • New n_workers kwarg (default 1, preserving existing behavior).
  • When n_workers > 1, FilterRecordingSegment.get_traces splits the channel axis into contiguous blocks and runs scipy.signal.sosfiltfilt / sosfilt on each block in a per-caller-thread ThreadPoolExecutor.
  • Graceful fallback to serial when channel count is smaller than 2 * n_workers.
  • scipy's SOS C implementations release the GIL during per-column work, so Python-thread parallelism delivers real speedup.
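To make the channel-split concrete, here is a minimal sketch of the approach (illustrative names and signature, not the PR's exact code): the channel axis is cut into contiguous blocks, each block is filtered on its own thread, and a small-channel-count guard falls back to the serial path.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from scipy.signal import sosfiltfilt


def filter_channels_parallel(traces, sos, n_workers=8):
    """Filter each contiguous channel block on its own thread (sketch)."""
    n_channels = traces.shape[1]
    if n_workers <= 1 or n_channels < 2 * n_workers:
        # Serial fallback, mirroring the PR's small-channel-count guard.
        return sosfiltfilt(sos, traces, axis=0)
    out = np.empty_like(traces)
    bounds = np.linspace(0, n_channels, n_workers + 1, dtype=int)

    def _filter_block(lo, hi):
        # scipy's SOS kernels release the GIL, so these run concurrently.
        out[:, lo:hi] = sosfiltfilt(sos, traces[:, lo:hi], axis=0)

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        futures = [pool.submit(_filter_block, lo, hi)
                   for lo, hi in zip(bounds[:-1], bounds[1:])]
        for f in futures:
            f.result()  # re-raise any worker exception
    return out
```

Because each channel is filtered independently, splitting the channel axis cannot change any channel's result; the blocks only decide which thread does the work.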

2. CommonReferenceRecording(n_workers=N) — time-parallel reduction

File: src/spikeinterface/preprocessing/common_reference.py

  • New n_workers kwarg (default 1).
  • Only the common global-reference path (group_indices=None, reference="global", ref_channel_ids=None) is parallelized — every other configuration delegates to the existing logic unchanged.
  • When n_workers > 1, _parallel_reduce_axis1 splits the time axis into blocks and runs np.median / np.mean per block in a per-caller-thread pool.
  • Below min_block=8192 samples per thread the overhead dominates; falls back to serial automatically.
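A sketch of the time-split reduction (the PR's helper is named `_parallel_reduce_axis1`; the exact signature here is assumed): each thread reduces its own block of samples, and below the per-thread minimum the function falls back to a single serial call.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def reduce_axis1_parallel(traces, reduce_fn=np.median, n_workers=4,
                          min_block=8192):
    """Reduce across channels (axis=1), time axis split over threads (sketch)."""
    n_samples = traces.shape[0]
    if n_workers <= 1 or n_samples < min_block * n_workers:
        # Serial fallback: below ~min_block samples per thread, pool
        # overhead exceeds the reduction cost.
        return reduce_fn(traces, axis=1)
    bounds = np.linspace(0, n_samples, n_workers + 1, dtype=int)

    def _reduce_block(lo, hi):
        # Each sample's median/mean is computed over its whole channel row,
        # so splitting the time axis cannot change the median's result.
        return reduce_fn(traces[lo:hi], axis=1)

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        blocks = list(pool.map(_reduce_block, bounds[:-1], bounds[1:]))
    return np.concatenate(blocks)
```

For the median this partitioning is exact (it is a per-sample selection, not an accumulation), which is why the correctness table below reports bit-identical output on that path.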

3. Per-caller-thread inner pool design

Each outer thread that calls get_traces() on a parallel-enabled segment gets its own lazy ThreadPoolExecutor, tracked in a weakref.WeakKeyDictionary keyed by the calling Thread object with a weakref.finalize(thread, pool.shutdown, wait=False) cleanup hook. A single shared inner pool would bottleneck under CRE — at n_jobs=24, n_workers=2, 24 outer threads submitting 2 tasks each into a 2-worker pool measured 3.36 s; per-caller pools measured 1.47 s (2.3× faster at the same thread budget). Keying by Thread (not thread-id integer) avoids thread-id reuse; the weakref + finalize pair ensures long-running processes don't accumulate zombie pools.
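The pool-ownership pattern described above can be sketched as follows (illustrative names, not the PR's exact code): a `WeakKeyDictionary` maps each calling `Thread` to its lazily created executor, and `weakref.finalize` shuts the pool down when the thread object is collected.

```python
import threading
import weakref
from concurrent.futures import ThreadPoolExecutor

_pools = weakref.WeakKeyDictionary()  # Thread -> ThreadPoolExecutor
_pools_lock = threading.Lock()


def get_caller_pool(n_workers):
    """Return this thread's private inner pool, creating it on first use."""
    caller = threading.current_thread()
    with _pools_lock:
        pool = _pools.get(caller)
        if pool is None:
            pool = ThreadPoolExecutor(max_workers=n_workers)
            _pools[caller] = pool
            # Non-blocking shutdown when the calling thread is collected;
            # the caller submits and joins synchronously, so no tasks are
            # in flight when it exits.
            weakref.finalize(caller, pool.shutdown, wait=False)
        return pool
```

Keying on the `Thread` object rather than `threading.get_ident()` matters because integer thread IDs can be reused after a thread dies, whereas a dead `Thread` object simply drops out of the weak mapping.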

Correctness

| Path | Check | Result |
| --- | --- | --- |
| Parallel SOS vs stock | `np.allclose(rtol=1e-5)` | Pass — float-equivalent |
| Parallel median vs stock | `np.array_equal` | Pass — bit-identical |
| Parallel mean vs stock | `np.allclose(rtol=1e-5)` | Within 1 ULP (non-associative sum across block partitions) |
| Single caller reuses pool | `pool_a is pool_b` | Pass (`tests/test_parallel_pool_semantics.py`) |
| Two concurrent callers get distinct pools | `pool_a is not pool_b` | Pass (ibid.) |

All existing tests for both modules pass unchanged.
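The 1-ULP discrepancy on the mean path is the textbook consequence of floating-point addition being non-associative: partitioning the time axis changes the order in which partial sums are combined, so the result can land on a neighbouring float. A two-line demonstration:

```python
# Floating-point addition is not associative: regrouping the same three
# numbers changes which way the final rounding goes.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c   # one partition order -> 0.6000000000000001
right = a + (b + c)  # another order       -> 0.6
print(left == right)  # False
```

The median is a per-sample selection rather than an accumulation, which is why that path stays bit-identical under the same partitioning.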

Performance (reproducible)

benchmarks/preprocessing/bench_perf.py — synthetic NumpyRecording, 1M × 384 float32, measured on a 24-core x86_64 host.

Component-level (hot kernel only)

No SI plumbing — just raw scipy/numpy calls. Shows the ceiling for each kernel on this hardware:

| Kernel | Serial | Threaded | Speedup |
| --- | --- | --- | --- |
| `scipy.signal.sosfiltfilt` (1M × 384 float32) | 7.80 s | 2.67 s (8 threads) | 2.92× |
| `np.median(axis=1)` (1M × 384 float32) | 3.51 s | 0.33 s (16 threads) | 10.58× |

Per-stage end-to-end (rec.get_traces())

Full SI preprocessing class through get_traces(), including margin fetch, buffer copies, casts, and subtraction:

| Stage | Stock (`n_workers=1`) | Parallel | Speedup | Equivalence |
| --- | --- | --- | --- | --- |
| Bandpass (5th-order Butterworth 300–6000 Hz, 1M × 384 float32) | 8.59 s | 3.20 s (`n_workers=8`) | 2.69× | matches stock within float32 tolerance |
| CMR median (global, 1M × 384 float32) | 4.01 s | 0.81 s (`n_workers=16`) | 4.95× | bitwise-identical to stock |

End-to-end ratios are lower than component-level because the non-parallelizable glue (margin fetch, dtype cast, subtract) dilutes the speedup. Bandpass and CMR scale sub-linearly with thread count due to DRAM bandwidth saturation.

Pareto frontier under CRE: outer × inner

At chunk_duration="1s" (SI default), different splits of a ~24-thread compute budget on the BP+CMR pipeline, per-caller-thread pools:

| Budget | Config | Time | Notes |
| --- | --- | --- | --- |
| 24 threads | outer=24, inner=1 each | 1.54 s | clean, minimum thread count |
| 192 threads | outer=24, inner=8 each | 1.42 s | absolute peak, oversubscribed |
| 24 threads | outer=8, inner=3 each | 1.53 s | tied with outer=24, inner=1; ~⅓ RAM |
| 12 threads | outer=12, inner=1 each | 1.59 s | ~½ RAM of outer=24 |
| 12 threads | outer=6, inner=2 each | 1.75 s | ~¼ RAM of outer=24 |
| 12 threads | outer=4, inner=3 each | 1.92 s | ~⅙ RAM of outer=24 |
| 12 threads | outer=1, inner=12 each | 4.31 s | single caller — sync overhead dominates |

Key observations:

  • DRAM bandwidth saturates around 12 outer workers. outer=12, inner=1 is within 3% of outer=24, inner=1; doubling cores past that gives diminishing returns.
  • Moderate outer + small inner is RAM-efficient. outer=6, inner=2 reaches 92% of peak at ~¼ the RAM; outer=8, inner=3 matches outer=12, inner=1 at ⅔ the RAM.

CRE interaction tables

For BP specifically (inner pool = 8, matching CRE n_jobs=8):

| Config | Time | Speedup | Parallelism axis |
| --- | --- | --- | --- |
| stock, CRE n=1 (baseline) | 7.42 s | 1.00× | (none) |
| stock, CRE n=8 thread | 1.40 s | 5.29× | outer only |
| `n_workers=8`, CRE n=1 | 3.18 s | 2.33× | inner only |
| `n_workers=8`, CRE n=8 thread | 1.24 s | 6.00× | both |

For CMR (inner pool = 16, exceeds CRE n_jobs=8):

| Config | Time | Speedup | Parallelism axis |
| --- | --- | --- | --- |
| stock, CRE n=1 (baseline) | 3.98 s | 1.00× | (none) |
| stock, CRE n=8 thread | 0.61 s | 6.47× | outer only |
| `n_workers=16`, CRE n=1 | 1.58 s | 2.52× | inner only |
| `n_workers=16`, CRE n=8 thread | 0.36 s | 11.01× | both |

Tuning guidance

Recommended configurations by caller posture:

| Caller | Recommended | Expected |
| --- | --- | --- |
| Direct `get_traces()` on large windows (viewer, streaming consumer) | `n_workers=core_count // 8` (more gives diminishing returns) | 2–3× vs serial |
| Default SI user (`n_jobs=1`) | `n_workers=8–16` per stage | 2.7–5× per stage |
| CRE `n_jobs ≥ 12`, RAM-rich | `n_workers=1` (outer already near DRAM ceiling) | 0–5% gain |
| CRE `n_jobs < 12`, RAM-constrained | `n_workers = cores // n_jobs` plus a margin | within 10% of peak at significant RAM savings |
| Peak absolute throughput | `n_jobs=core_count` with `n_workers=4–8` (oversubscribed) | ~8% above outer-only |
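As a rough illustration only, the table above could be folded into a small helper. This function is hypothetical — it is not part of the PR or of SpikeInterface — and the thresholds simply transcribe the measurements reported here:

```python
import os


def suggest_n_workers(n_jobs=1):
    """Illustrative mapping of the tuning table; not a real API."""
    cores = os.cpu_count() or 1
    if n_jobs >= 12:
        return 1                    # outer workers already near the DRAM ceiling
    if n_jobs == 1:
        return max(1, cores // 8)   # direct get_traces() caller
    return max(1, cores // n_jobs)  # split the core budget across outer jobs
```

The point of the sketch is the shape of the decision, not the exact numbers: once outer parallelism saturates DRAM bandwidth, extra inner workers buy little; below that, inner workers soak up the idle core budget.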

Compatibility

  • No default behavior changes. n_workers=1 preserves existing semantics exactly.
  • Round-trip dumpability. _kwargs dicts updated on both preprocessors; save() / load() round-trip the new kwargs correctly.
  • No new deps. Uses stdlib concurrent.futures.ThreadPoolExecutor, threading, weakref.
  • Propagates through the bandpass_filter, highpass_filter, filter, notch_filter, common_reference wrapper functions via **filter_kwargs.
  • Long-running processes safe: per-caller pools are cleaned up when the calling thread is GC'd.

Review guide

  1. filter.py: _apply_sos helper + n_workers kwarg plumbing + WeakKeyDictionary pool map + weakref.finalize cleanup.
  2. common_reference.py: _parallel_reduce_axis1 helper + same pool-ownership pattern. Parallelization guarded to the global-reference hot path only.
  3. tests/test_parallel_pool_semantics.py: single-caller reuse + concurrent-caller isolation contract tests for both preprocessors.
  4. Existing tests: correctness (bit-identical median, float-close sos), thread-pool reuse across calls, single-worker-fallback for small channel counts.

Companion PR

An independent companion PR #4563 adds a sinc-FIR alternative to PhaseShiftRecording with ~100× per-stage speedup plus memory win. No code dependency between the two; either can land first. Combined, they give 13–20× on a typical PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecording chain for direct get_traces() callers, or ~3× on top of existing CRE parallelism.

Checklist

  • Existing preprocessing tests pass
  • New tests cover the per-caller-thread pool contract
  • Benchmark script with reproducible fixtures
  • Dumpable recordings (_kwargs updated)
  • Docstrings explain when the kwargs help / don't help
  • WeakKeyDictionary cleanup prevents pool leaks in long-running processes

Adds opt-in intra-chunk thread-parallelism to two preprocessors:
channel-split sosfilt/sosfiltfilt in FilterRecording, time-split
median/mean in CommonReferenceRecording.  Default n_workers=1 preserves
existing behavior.

Per-caller-thread inner pools
-----------------------------
Each outer thread that calls ``get_traces()`` on a parallel-enabled segment
gets its own inner ThreadPoolExecutor, stored in a ``WeakKeyDictionary``
keyed by the calling ``Thread`` object.  Rationale:

* Avoids the shared-pool queueing pathology that would occur if N outer
  workers (e.g., TimeSeriesChunkExecutor with n_jobs=N) all submitted
  into a single shared pool with fewer max_workers than outer callers.
  Under a shared pool, ``n_workers=2`` with ``n_jobs=24`` thrashed at
  3.36 s on the test pipeline; per-caller pools: 1.47 s.
* Keying by the Thread object (not thread-id integer) avoids the
  thread-id-reuse hazard: thread IDs can be reused after a thread dies,
  which would cause a new thread to silently inherit a dead thread's
  pool.
* WeakKeyDictionary + weakref.finalize ensures automatic shutdown of
  the inner pool when the calling thread is garbage-collected.  The
  finalizer calls ``pool.shutdown(wait=False)`` to avoid blocking the
  finalizer thread; in-flight tasks would be cancelled, but the owning
  thread submits+joins synchronously, so none exist when it exits.

When useful
-----------
* Direct ``get_traces()`` callers (interactive viewers, streaming
  consumers, mipmap-zarr tile builders) that don't use
  ``TimeSeriesChunkExecutor``.
* Default SI users who haven't tuned job_kwargs.
* RAM-constrained deployments that can't crank ``n_jobs`` to core count:
  on a 24-core host, ``n_jobs=6, n_workers=2`` gets within 8% of
  ``n_jobs=24, n_workers=1`` at ~1/4 the RAM.

Performance (1M × 384 float32 BP+CMR pipeline, 24-core host, thread engine)
---------------------------------------------------------------------------
  === Component-level (scipy/numpy only) ===
  sosfiltfilt serial → 8 threads:   7.80 s →  2.67 s (2.92x)
  np.median serial   → 16 threads:  3.51 s →  0.33 s (10.58x)

  === Per-stage end-to-end (rec.get_traces) ===
  Bandpass (5th-order, 300-6k Hz): 8.59 s →  3.20 s (2.69x)
  CMR median (global):             4.01 s →  0.81 s (4.95x)

  === CRE outer × inner Pareto, per-caller pools ===
  outer=24, inner=1 each:          1.54 s  (100% of peak)
  outer=24, inner=8 each:          1.42 s  (108% of peak; oversubscribed)
  outer=12, inner=1 each:          1.59 s  (97%, ~1/2 RAM of outer=24)
  outer=6,  inner=2 each:          1.75 s  (92%, ~1/4 RAM of outer=24)
  outer=4,  inner=6 each:          1.83 s  (87%, ~1/6 RAM with 24 threads)

Tests
-----
New ``test_parallel_pool_semantics.py`` verifies the per-caller-thread
contract: single caller reuses one pool; concurrent callers get distinct
pools.  Existing bandpass + CMR tests still pass.

Independent of the companion FIR phase-shift PR (perf/phase-shift-fir);
the two can land in either order.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>