Skip to content

Commit

Permalink
Merge pull request #2546 from chrishalcrow/waveform_backwards_compat
Browse files Browse the repository at this point in the history
Add more backwards compatibly for MockWaveformExtractor
  • Loading branch information
alejoe91 committed May 10, 2024
2 parents cd55f91 + 641ca35 commit ba3dfb4
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/spikeinterface/core/analyzer_extension_core.py
Expand Up @@ -365,7 +365,12 @@ def _compute_and_append_from_waveforms(self, operators):

# spikes = self.sorting_analyzer.sorting.to_spike_vector()
# some_spikes = spikes[self.sorting_analyzer.random_spikes_indices]

assert self.sorting_analyzer.has_extension(
"random_spikes"
), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()"
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()

for unit_index, unit_id in enumerate(unit_ids):
spike_mask = some_spikes["unit_index"] == unit_index
wfs = waveforms[spike_mask, :, :]
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_sortinganalyzer.py
Expand Up @@ -116,7 +116,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):
# compute
sorting_analyzer.compute("dummy", param1=5.5)
# equivalent
compute_dummy(sorting_analyzer, param1=5.5)
compute_dummy(sorting_analyzer=sorting_analyzer, param1=5.5)
ext = sorting_analyzer.get_extension("dummy")
assert ext is not None
assert ext.params["param1"] == 5.5
Expand Down
Expand Up @@ -138,6 +138,9 @@ def has_waveforms(self) -> bool:
def delete_waveforms(self) -> None:
self.sorting_analyzer.delete_extension("waveforms")

def delete_extension(self, extension) -> None:
self.sorting_analyzer.delete_extension()

@property
def recording(self) -> BaseRecording:
return self.sorting_analyzer.recording
Expand Down Expand Up @@ -222,6 +225,9 @@ def get_recording_property(self, key) -> np.ndarray:
def get_sorting_property(self, key) -> np.ndarray:
return self.sorting_analyzer.get_sorting_property(key)

def get_available_extension_names(self):
return self.sorting_analyzer.get_loaded_extension_names()

@property
def sparsity(self):
return self.sorting_analyzer.sparsity
Expand All @@ -231,17 +237,31 @@ def folder(self):
if self.sorting_analyzer.format != "memory":
return self.sorting_analyzer.folder

@property
def format(self):
if self.sorting_analyzer.format == "binary_folder":
return "binary"
else:
return self.sorting_analyzer.format

def has_extension(self, extension_name: str) -> bool:
return self.sorting_analyzer.has_extension(extension_name)

def select_units(self, unit_ids):
return self.sorting_analyzer.select_units(unit_ids)

def get_sampled_indices(self, unit_id):
# In Waveforms extractor "selected_spikes" was a dict (key: unit_id) with a complex dtype as follow
selected_spikes = []
for segment_index in range(self.get_num_segments()):
# inds = self.sorting_analyzer.get_selected_indices_in_spike_train(unit_id, segment_index)
assert self.sorting_analyzer.has_extension(
"random_spikes"
), "get_sampled_indices() requires the 'random_spikes' extension."
inds = self.sorting_analyzer.get_extension("random_spikes").get_selected_indices_in_spike_train(
unit_id, segment_index
)

sampled_index = np.zeros(inds.size, dtype=[("spike_index", "int64"), ("segment_index", "int64")])
sampled_index["spike_index"] = inds
sampled_index["segment_index"][:] = segment_index
Expand All @@ -260,7 +280,13 @@ def get_waveforms(
# lazy and cache are ingnored
ext = self.sorting_analyzer.get_extension("waveforms")
unit_index = self.sorting.id_to_index(unit_id)

assert self.sorting_analyzer.has_extension(
"random_spikes"
), "get_sampled_indices() requires the 'random_spikes' extension."

some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()

spike_mask = some_spikes["unit_index"] == unit_index
wfs = ext.data["waveforms"][spike_mask, :, :]

Expand Down Expand Up @@ -484,6 +510,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
ext = ComputeRandomSpikes(sorting_analyzer)
ext.params = dict()
ext.data = dict(random_spikes_indices=random_spikes_indices)
sorting_analyzer.extensions["random_spikes"] = ext

ext = ComputeWaveforms(sorting_analyzer)
ext.params = dict(
Expand Down
6 changes: 6 additions & 0 deletions src/spikeinterface/postprocessing/correlograms.py
Expand Up @@ -4,6 +4,8 @@
import numpy as np
from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension, SortingAnalyzer

from spikeinterface.core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor

try:
import numba

Expand Down Expand Up @@ -87,6 +89,10 @@ def compute_correlograms(
bin_ms: float = 1.0,
method: str = "auto",
):

if isinstance(sorting_analyzer_or_sorting, MockWaveformExtractor):
sorting_analyzer_or_sorting = sorting_analyzer_or_sorting.sorting

if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
return compute_correlograms_sorting_analyzer(
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method
Expand Down

0 comments on commit ba3dfb4

Please sign in to comment.