Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to instantiate NWB extractors from nwbfile object #2506

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 99 additions & 37 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,8 @@ class NwbRecordingExtractor(BaseRecording):
use_pynwb: bool, default: False
Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py
to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations.
nwbfile: NWBFile or None, default: None
The NWBFile object. If provided, the extractor will use this object instead of reading the file.

Returns
-------
Expand Down Expand Up @@ -503,11 +505,12 @@ def __init__(
cache: bool = False,
storage_options: dict | None = None,
use_pynwb: bool = False,
nwbfile: NWBFile | None = None,
):
if file_path is not None and file is not None:
raise ValueError("Provide either file_path or file, not both")
if file_path is None and file is None:
raise ValueError("Provide either file_path or file")
if file_path is not None and file is not None and nwbfile is not None:
raise ValueError("Provide either file_path or file or nwbfile, not both")
if file_path is None and file is None and nwbfile is None:
raise ValueError("Provide either file_path or file or nwbfile")

if electrical_series_name is not None:
warning_msg = (
Expand All @@ -527,13 +530,26 @@ def __init__(
self.storage_options = storage_options
self.electrical_series_path = electrical_series_path

if self.stream_mode is None and file is None:
self.backend = _get_backend_from_local_file(file_path)
if nwbfile is None:
if self.stream_mode is None:
self.backend = _get_backend_from_local_file(file_path)
else:
if self.stream_mode == "zarr":
self.backend = "zarr"
else:
self.backend = "hdf5"
else:
if self.stream_mode == "zarr":
self.backend = "zarr"
from pynwb import NWBHDF5IO

use_pynwb = True
io = nwbfile.get_read_io()
if io is not None:
if isinstance(io, NWBHDF5IO):
self.backend = "hdf5"
else:
self.backend = "zarr"
else:
self.backend = "hdf5"
raise FileNotFoundError("In memory NWBFile is not supported")

# extract info
if use_pynwb:
Expand All @@ -548,7 +564,9 @@ def __init__(
dtype,
segment_data,
times_kwargs,
) = self._fetch_recording_segment_info_pynwb(file, cache, load_time_vector, samples_for_rate_estimation)
) = self._fetch_recording_segment_info_pynwb(
file, cache, load_time_vector, samples_for_rate_estimation, nwbfile=nwbfile
)
else:
(
channel_ids,
Expand Down Expand Up @@ -606,6 +624,12 @@ def __init__(
# not json serializable if file arg is provided
self._serializability["json"] = False

if nwbfile is not None:
# NWBFile is not serializable
self._serializability["memory"] = False
self._serializability["json"] = False
self._serializability["pickle"] = False
Comment on lines +631 to +635
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

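@CodyCBakerPhD Note that this makes the extractor unsuitable for parallel processing: because the NWBFile object cannot be serialized, memory, JSON, and pickle serialization are all disabled here. For that reason I would not recommend this approach, IMO.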


if storage_options is not None and stream_mode == "zarr":
warnings.warn(
"The `storage_options` parameter will not be propagated to JSON or pickle files for security reasons, "
Expand All @@ -626,6 +650,7 @@ def __init__(
"cache": cache,
"stream_cache_path": stream_cache_path,
"file": file,
"nwbfile": nwbfile,
}

def __del__(self):
Expand All @@ -641,15 +666,20 @@ def __del__(self):
if io is not None:
io.close()

def _fetch_recording_segment_info_pynwb(self, file, cache, load_time_vector, samples_for_rate_estimation):
self._nwbfile = read_nwbfile(
backend=self.backend,
file_path=self.file_path,
file=file,
stream_mode=self.stream_mode,
cache=cache,
stream_cache_path=self.stream_cache_path,
)
def _fetch_recording_segment_info_pynwb(
self, file, cache, load_time_vector, samples_for_rate_estimation, nwbfile=None
):
if nwbfile is None:
self._nwbfile = read_nwbfile(
backend=self.backend,
file_path=self.file_path,
file=file,
stream_mode=self.stream_mode,
cache=cache,
stream_cache_path=self.stream_cache_path,
)
else:
self._nwbfile = nwbfile
electrical_series = _retrieve_electrical_series_pynwb(self._nwbfile, self.electrical_series_path)
# The indices in the electrode table corresponding to this electrical series
electrodes_indices = electrical_series.electrodes.data[:]
Expand Down Expand Up @@ -912,7 +942,7 @@ class NwbSortingExtractor(BaseSorting):
"""Load an NWBFile as a SortingExtractor.
Parameters
----------
file_path: str or Path
file_path: str or Path or None, default: None
Path to NWB file.
electrical_series_path: str or None, default: None
The name of the ElectricalSeries (if multiple ElectricalSeries are present).
Expand Down Expand Up @@ -949,6 +979,8 @@ class NwbSortingExtractor(BaseSorting):
use_pynwb: bool, default: False
Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py
to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations.
nwbfile: NWBFile or None, default: None
If an NWBFile is provided, the extractor will use it instead of reading the file from disk.

Returns
-------
Expand All @@ -963,7 +995,7 @@ class NwbSortingExtractor(BaseSorting):

def __init__(
self,
file_path: str | Path,
file_path: str | Path | None = None,
electrical_series_path: str | None = None,
sampling_frequency: float | None = None,
samples_for_rate_estimation: int = 1_000,
Expand All @@ -976,6 +1008,7 @@ def __init__(
cache: bool = False,
storage_options: dict | None = None,
use_pynwb: bool = False,
nwbfile: NWBFile | None = None,
):
self.stream_mode = stream_mode
self.stream_cache_path = stream_cache_path
Expand All @@ -986,13 +1019,26 @@ def __init__(
self.storage_options = storage_options
self.units_table = None

if self.stream_mode is None:
self.backend = _get_backend_from_local_file(file_path)
if nwbfile is None:
if self.stream_mode is None:
self.backend = _get_backend_from_local_file(file_path)
else:
if self.stream_mode == "zarr":
self.backend = "zarr"
else:
self.backend = "hdf5"
else:
if self.stream_mode == "zarr":
self.backend = "zarr"
from pynwb import NWBHDF5IO

use_pynwb = True
io = nwbfile.get_read_io()
if io is not None:
if isinstance(io, NWBHDF5IO):
self.backend = "hdf5"
else:
self.backend = "zarr"
else:
self.backend = "hdf5"
raise FileNotFoundError("In memory NWBFile is not supported")

if use_pynwb:
try:
Expand All @@ -1001,7 +1047,10 @@ def __init__(
raise ImportError(self.installation_mesg)

unit_ids, spike_times_data, spike_times_index_data = self._fetch_sorting_segment_info_pynwb(
unit_table_path=unit_table_path, samples_for_rate_estimation=samples_for_rate_estimation, cache=cache
unit_table_path=unit_table_path,
samples_for_rate_estimation=samples_for_rate_estimation,
cache=cache,
nwbfile=nwbfile,
)
else:
unit_ids, spike_times_data, spike_times_index_data = self._fetch_sorting_segment_info_backend(
Expand Down Expand Up @@ -1034,6 +1083,12 @@ def __init__(
if stream_mode is None and file_path is not None:
file_path = str(Path(file_path).resolve())

if nwbfile is not None:
# NWBFile is not serializable
self._serializability["memory"] = False
self._serializability["json"] = False
self._serializability["pickle"] = False

if storage_options is not None and stream_mode == "zarr":
warnings.warn(
"The `storage_options` parameter will not be propagated to JSON or pickle files for security reasons, "
Expand All @@ -1054,6 +1109,7 @@ def __init__(
"storage_options": storage_options,
"load_unit_properties": load_unit_properties,
"t_start": self.t_start,
"nwbfile": nwbfile,
}

def __del__(self):
Expand All @@ -1070,17 +1126,23 @@ def __del__(self):
io.close()

def _fetch_sorting_segment_info_pynwb(
self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False
self,
unit_table_path: str = None,
samples_for_rate_estimation: int = 1000,
cache: bool = False,
nwbfile: NWBFile | None = None,
):
self._nwbfile = read_nwbfile(
backend=self.backend,
file_path=self.file_path,
stream_mode=self.stream_mode,
cache=cache,
stream_cache_path=self.stream_cache_path,
storage_options=self.storage_options,
)

if nwbfile is None:
self._nwbfile = read_nwbfile(
backend=self.backend,
file_path=self.file_path,
stream_mode=self.stream_mode,
cache=cache,
stream_cache_path=self.stream_cache_path,
storage_options=self.storage_options,
)
else:
self._nwbfile = nwbfile
timestamps = None
if self.provided_or_electrical_series_sampling_frequency is None:
# defines the electrical series from where the sorting came from
Expand Down
Loading
Loading