Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions spikeinterface/sorters/basesorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
class BaseSorter:
"""Base Sorter object."""

sorter_name = '' # convenience for reporting
sorter_name = "" # convenience for reporting
compiled_name = None
SortingExtractor_Class = None # convenience to get the extractor
requires_locations = False
docker_requires_gpu = False
gpu_capability = 'not-supported'
compatible_with_parallel = {'loky': True, 'multiprocessing': True, 'threading': True}

_default_params = {}
_params_description = {}
sorter_description = ""
Expand Down Expand Up @@ -289,6 +290,10 @@ def check_compiled(cls):
if retcode != 0:
return False
return True

@classmethod
def use_gpu(cls, params):
return cls.gpu_capability != 'not-supported'

#############################################

Expand Down
6 changes: 5 additions & 1 deletion spikeinterface/sorters/ironclust/ironclust.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class IronClustSorter(BaseSorter):
ironclust_path: Union[str, None] = os.getenv('IRONCLUST_PATH', None)

requires_locations = True
docker_requires_gpu = True
gpu_capability = 'nvidia-optional'

_default_params = {
'detect_sign': -1, # Use -1, 0, or 1, depending on the sign of the spikes in the recording
Expand Down Expand Up @@ -142,6 +142,10 @@ def get_sorter_version(cls):
version = d['version']
return version
return 'unknown'

@classmethod
def use_gpu(cls, params):
return params["fGpu"]

@staticmethod
def set_ironclust_path(ironclust_path: PathType):
Expand Down
6 changes: 5 additions & 1 deletion spikeinterface/sorters/kilosort/kilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class KilosortSorter(KilosortBase, BaseSorter):
compiled_name: str = 'ks_compiled'
kilosort_path: Union[str, None] = os.getenv('KILOSORT_PATH', None)
requires_locations = False
docker_requires_gpu = True
requires_gpu = 'nvidia-optional'

_default_params = {
'detect_threshold': 6,
Expand Down Expand Up @@ -89,6 +89,10 @@ def get_sorter_version(cls):
return 'unknown'
else:
return 'git-' + commit

@classmethod
def use_gpu(cls, params):
return params["useGPU"]

@classmethod
def set_kilosort_path(cls, kilosort_path: str):
Expand Down
1 change: 0 additions & 1 deletion spikeinterface/sorters/kilosort2/kilosort2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class Kilosort2Sorter(KilosortBase, BaseSorter):
compiled_name: str = 'ks2_compiled'
kilosort2_path: Union[str, None] = os.getenv('KILOSORT2_PATH', None)
requires_locations = False
docker_requires_gpu = True

_default_params = {
'detect_threshold': 6,
Expand Down
1 change: 0 additions & 1 deletion spikeinterface/sorters/kilosort2_5/kilosort2_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter):
compiled_name: str = 'ks2_5_compiled'
kilosort2_5_path: Union[str, None] = os.getenv('KILOSORT2_5_PATH', None)
requires_locations = False
docker_requires_gpu = True

_default_params = {
'detect_threshold': 6,
Expand Down
1 change: 0 additions & 1 deletion spikeinterface/sorters/kilosort3/kilosort3.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class Kilosort3Sorter(KilosortBase, BaseSorter):
compiled_name: str = 'ks3_compiled'
kilosort3_path: Union[str, None] = os.getenv('KILOSORT3_PATH', None)
requires_locations = False
docker_requires_gpu = True

_default_params = {
'detect_threshold': 6,
Expand Down
1 change: 1 addition & 0 deletions spikeinterface/sorters/kilosortbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class KilosortBase:
* _run_from_folder
* _get_result_from_folder
"""
gpu_capability = 'nvidia-required'

@staticmethod
def _generate_channel_map_file(recording, output_folder):
Expand Down
2 changes: 1 addition & 1 deletion spikeinterface/sorters/pykilosort/pykilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PyKilosortSorter(BaseSorter):

sorter_name = 'pykilosort'
requires_locations = False
docker_requires_gpu = True
gpu_capability = 'nvidia-required'
compatible_with_parallel = {'loky': True, 'multiprocessing': False, 'threading': False}

_default_params = {
Expand Down
29 changes: 23 additions & 6 deletions spikeinterface/sorters/runsorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..version import version as si_version
from spikeinterface.core.core_tools import check_json, recursive_path_modifier, is_dict_extractor
from .sorterlist import sorter_dict
from .utils import SpikeSortingError
from .utils import SpikeSortingError, has_nvidia


_common_param_doc = """
Expand Down Expand Up @@ -160,12 +160,14 @@ class ContainerClient:
def __init__(self, mode, container_image, volumes, extra_kwargs):
assert mode in ('docker', 'singularity')
self.mode = mode
container_requires_gpu = extra_kwargs.get(
'container_requires_gpu', None)

if mode == 'docker':
import docker
client = docker.from_env()
if extra_kwargs.get('requires_gpu', False):
extra_kwargs.pop('requires_gpu')
if container_requires_gpu is not None:
extra_kwargs.pop('container_requires_gpu')
extra_kwargs["device_requests"] = [
docker.types.DeviceRequest(count=-1, capabilities=[['gpu']])]

Expand Down Expand Up @@ -198,7 +200,7 @@ def __init__(self, mode, container_image, volumes, extra_kwargs):
options=['--bind', singularity_bind]

# gpu options
if extra_kwargs.get('requires_gpu', False):
if container_requires_gpu:
# only nvidia at the moment
options += ['--nv']

Expand Down Expand Up @@ -306,8 +308,23 @@ def run_sorter_container(sorter_name, recording, mode, container_image, output_f
install_si_from_source = False

extra_kwargs = {}
if SorterClass.docker_requires_gpu:
extra_kwargs['requires_gpu'] = True
use_gpu = SorterClass.use_gpu(sorter_params)
gpu_capability = SorterClass.gpu_capability

if use_gpu:
if gpu_capability == 'nvidia-required':
assert has_nvidia(), "The container requires a NVIDIA GPU capability, but it is not available"
extra_kwargs['container_requires_gpu'] = True
elif gpu_capability == 'nvidia-optional':
if has_nvidia():
extra_kwargs['container_requires_gpu'] = True
else:
if verbose:
print(f"{SorterClass.sorter_name} supports GPU, but no GPU is available.\n"
f"Running the sorter without GPU")
else:
# TODO: make opencl machanism
raise NotImplementedError("Only nvidia support is available")

container_client = ContainerClient(mode, container_image, volumes, extra_kwargs)
if verbose:
Expand Down
2 changes: 1 addition & 1 deletion spikeinterface/sorters/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .shellscript import ShellScript
from .misc import (SpikeSortingError, get_git_commit)
from .misc import (SpikeSortingError, get_git_commit, has_nvidia)
11 changes: 11 additions & 0 deletions spikeinterface/sorters/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,14 @@ def get_git_commit(git_folder, shorten=True):
except:
commit = None
return commit


def has_nvidia():
"""
Checks if the machine has nvidia capability.
"""
try:
check_output('nvidia-smi')
return True
except Exception: # this command not being found can raise quite a few different errors depending on the configuration
return False
2 changes: 1 addition & 1 deletion spikeinterface/sorters/yass/yass.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class YassSorter(BaseSorter):

sorter_name = 'yass'
requires_locations = False
docker_requires_gpu = True
gpu_capability = 'nvidia-required'

# #################################################

Expand Down