src/spikeinterface/core/analyzer_extension_core.py — 14 changes: 12 additions & 2 deletions

@@ -691,12 +691,13 @@ class ComputeNoiseLevels(AnalyzerExtension):
     need_recording = True
     use_nodepipeline = False
     need_job_kwargs = False
+    need_backward_compatibility_on_load = True

     def __init__(self, sorting_analyzer):
         AnalyzerExtension.__init__(self, sorting_analyzer)

-    def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None):
-        params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed)
+    def _set_params(self, **noise_level_params):
+        params = noise_level_params.copy()
         return params

     def _select_extension_data(self, unit_ids):
@@ -717,6 +718,15 @@ def _run(self, verbose=False):
     def _get_data(self):
         return self.data["noise_levels"]

+    def _handle_backward_compatibility_on_load(self):
+        # The old parameters used to be params=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None).
+        # They are now handled more explicitly through random_slices_kwargs=dict().
+        for key in ("num_chunks_per_segment", "chunk_size", "seed"):
+            if key in self.params:
+                if "random_slices_kwargs" not in self.params:
+                    self.params["random_slices_kwargs"] = dict()
+                self.params["random_slices_kwargs"][key] = self.params.pop(key)
+

 register_result_extension(ComputeNoiseLevels)
 compute_noise_levels = ComputeNoiseLevels.function_factory()
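To see what the new _handle_backward_compatibility_on_load hook does, here is a minimal, self-contained sketch of the equivalent remapping applied to a legacy params dict (the "method" key is only illustrative of other parameters that must stay untouched):

    # Params as saved by older versions: the three random-chunk keys live at the top level.
    params = {"num_chunks_per_segment": 20, "chunk_size": 10000, "seed": None, "method": "mad"}

    # Equivalent logic to _handle_backward_compatibility_on_load: move the legacy
    # keys under a single "random_slices_kwargs" sub-dict, leaving other keys alone.
    for key in ("num_chunks_per_segment", "chunk_size", "seed"):
        if key in params:
            params.setdefault("random_slices_kwargs", {})[key] = params.pop(key)

    print(params)
    # {'method': 'mad', 'random_slices_kwargs': {'num_chunks_per_segment': 20, 'chunk_size': 10000, 'seed': None}}

The remapping is also a no-op on params already saved in the new format, since none of the legacy keys are present there.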
src/spikeinterface/core/job_tools.py — 35 changes: 21 additions & 14 deletions

@@ -185,6 +185,22 @@ def ensure_n_jobs(recording, n_jobs=1):
     return n_jobs


+def chunk_duration_to_chunk_size(chunk_duration, recording):
+    if isinstance(chunk_duration, float):
+        chunk_size = int(chunk_duration * recording.get_sampling_frequency())
+    elif isinstance(chunk_duration, str):
+        if chunk_duration.endswith("ms"):
+            chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
+        elif chunk_duration.endswith("s"):
+            chunk_duration = float(chunk_duration.replace("s", ""))
+        else:
+            raise ValueError("chunk_duration must end with 's' or 'ms'")
+        chunk_size = int(chunk_duration * recording.get_sampling_frequency())
+    else:
+        raise ValueError("chunk_duration must be str or float")
+    return chunk_size
+
+
 def ensure_chunk_size(
     recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
 ):
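A quick sanity check of the extracted helper on a toy recording (a sketch; generate_recording is just a convenient way to obtain a Recording with a known sampling frequency):

    from spikeinterface.core import generate_recording
    from spikeinterface.core.job_tools import chunk_duration_to_chunk_size

    rec = generate_recording(sampling_frequency=30000.0, durations=[10.0])

    chunk_duration_to_chunk_size(0.5, rec)      # float seconds -> 15000 samples
    chunk_duration_to_chunk_size("500ms", rec)  # "ms" suffix   -> 15000 samples
    chunk_duration_to_chunk_size("2s", rec)     # "s" suffix    -> 60000 samples

Note that the "ms" branch must be tested before the "s" branch, since a string like "500ms" also ends with "s".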
@@ -231,18 +247,7 @@ def ensure_chunk_size(
         num_channels = recording.get_num_channels()
         chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs))
     elif chunk_duration is not None:
-        if isinstance(chunk_duration, float):
-            chunk_size = int(chunk_duration * recording.get_sampling_frequency())
-        elif isinstance(chunk_duration, str):
-            if chunk_duration.endswith("ms"):
-                chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
-            elif chunk_duration.endswith("s"):
-                chunk_duration = float(chunk_duration.replace("s", ""))
-            else:
-                raise ValueError("chunk_duration must end with 's' or 'ms'")
-            chunk_size = int(chunk_duration * recording.get_sampling_frequency())
-        else:
-            raise ValueError("chunk_duration must be str or float")
+        chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording)
     else:
         # Edge case to define single chunk per segment for n_jobs=1.
         # All chunking parameters equal None mean single chunk per segment
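Since ensure_chunk_size now simply delegates to the helper, its observable behavior is unchanged. Assuming the same toy recording as above, these three calls should all agree (a sketch, not a test from the PR):

    from spikeinterface.core.job_tools import ensure_chunk_size

    assert ensure_chunk_size(rec, chunk_duration="1s") == 30000
    assert ensure_chunk_size(rec, chunk_duration=1.0) == 30000
    assert ensure_chunk_size(rec, chunk_size=30000) == 30000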
@@ -382,11 +387,13 @@ def __init__(
f"chunk_duration={chunk_duration_str}",
)

def run(self):
def run(self, all_chunks=None):
"""
Runs the defined jobs.
"""
all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)

if all_chunks is None:
all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)

if self.handle_returns:
returns = []
Expand Down
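The new all_chunks parameter lets a caller precompute (and, for example, subsample) the chunk list instead of always covering the whole recording. A sketch of the intended use, assuming each chunk is a (segment_index, frame_start, frame_stop) tuple as produced by divide_recording_into_chunks, and an executor built in the usual way:

    from spikeinterface.core.job_tools import divide_recording_into_chunks

    all_chunks = divide_recording_into_chunks(rec, 30000)
    some_chunks = all_chunks[:5]  # keep only the first five chunks

    # executor = ChunkRecordingExecutor(rec, func, init_func, init_args, ...)
    # executor.run()                        # unchanged default: chunk the whole recording
    # executor.run(all_chunks=some_chunks)  # new: process only the selected chunks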