diff --git a/spikeinterface/sorters/basesorter.py b/spikeinterface/sorters/basesorter.py index d90b0cceec..b56ad6ab43 100644 --- a/spikeinterface/sorters/basesorter.py +++ b/spikeinterface/sorters/basesorter.py @@ -22,12 +22,13 @@ class BaseSorter: """Base Sorter object.""" - sorter_name = '' # convenience for reporting + sorter_name = "" # convenience for reporting compiled_name = None SortingExtractor_Class = None # convenience to get the extractor requires_locations = False - docker_requires_gpu = False + gpu_capability = 'not-supported' compatible_with_parallel = {'loky': True, 'multiprocessing': True, 'threading': True} + _default_params = {} _params_description = {} sorter_description = "" @@ -289,6 +290,10 @@ def check_compiled(cls): if retcode != 0: return False return True + + @classmethod + def use_gpu(cls, params): + return cls.gpu_capability != 'not-supported' ############################################# diff --git a/spikeinterface/sorters/ironclust/ironclust.py b/spikeinterface/sorters/ironclust/ironclust.py index cfa01663f9..0d00237e94 100644 --- a/spikeinterface/sorters/ironclust/ironclust.py +++ b/spikeinterface/sorters/ironclust/ironclust.py @@ -34,7 +34,7 @@ class IronClustSorter(BaseSorter): ironclust_path: Union[str, None] = os.getenv('IRONCLUST_PATH', None) requires_locations = True - docker_requires_gpu = True + gpu_capability = 'nvidia-optional' _default_params = { 'detect_sign': -1, # Use -1, 0, or 1, depending on the sign of the spikes in the recording @@ -142,6 +142,10 @@ def get_sorter_version(cls): version = d['version'] return version return 'unknown' + + @classmethod + def use_gpu(cls, params): + return params["fGpu"] @staticmethod def set_ironclust_path(ironclust_path: PathType): diff --git a/spikeinterface/sorters/kilosort/kilosort.py b/spikeinterface/sorters/kilosort/kilosort.py index ea3b174f41..e72e6b0eee 100644 --- a/spikeinterface/sorters/kilosort/kilosort.py +++ b/spikeinterface/sorters/kilosort/kilosort.py @@ -30,7 +30,7 @@ class KilosortSorter(KilosortBase, BaseSorter): compiled_name: str = 'ks_compiled' kilosort_path: Union[str, None] = os.getenv('KILOSORT_PATH', None) requires_locations = False - docker_requires_gpu = True + requires_gpu = 'nvidia-optional' _default_params = { 'detect_threshold': 6, @@ -89,6 +89,10 @@ def get_sorter_version(cls): return 'unknown' else: return 'git-' + commit + + @classmethod + def use_gpu(cls, params): + return params["useGPU"] @classmethod def set_kilosort_path(cls, kilosort_path: str): diff --git a/spikeinterface/sorters/kilosort2/kilosort2.py b/spikeinterface/sorters/kilosort2/kilosort2.py index 646adadbb0..02a2867777 100644 --- a/spikeinterface/sorters/kilosort2/kilosort2.py +++ b/spikeinterface/sorters/kilosort2/kilosort2.py @@ -30,7 +30,6 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): compiled_name: str = 'ks2_compiled' kilosort2_path: Union[str, None] = os.getenv('KILOSORT2_PATH', None) requires_locations = False - docker_requires_gpu = True _default_params = { 'detect_threshold': 6, diff --git a/spikeinterface/sorters/kilosort2_5/kilosort2_5.py b/spikeinterface/sorters/kilosort2_5/kilosort2_5.py index f7de101a3a..0a4f93f8f9 100644 --- a/spikeinterface/sorters/kilosort2_5/kilosort2_5.py +++ b/spikeinterface/sorters/kilosort2_5/kilosort2_5.py @@ -33,7 +33,6 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): compiled_name: str = 'ks2_5_compiled' kilosort2_5_path: Union[str, None] = os.getenv('KILOSORT2_5_PATH', None) requires_locations = False - docker_requires_gpu = True _default_params = { 'detect_threshold': 6, diff --git a/spikeinterface/sorters/kilosort3/kilosort3.py b/spikeinterface/sorters/kilosort3/kilosort3.py index 308e95fd07..047b61d9f6 100644 --- a/spikeinterface/sorters/kilosort3/kilosort3.py +++ b/spikeinterface/sorters/kilosort3/kilosort3.py @@ -31,7 +31,6 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): compiled_name: str = 'ks3_compiled' kilosort3_path: Union[str, None] = os.getenv('KILOSORT3_PATH', None) requires_locations = False - docker_requires_gpu = True _default_params = { 'detect_threshold': 6, diff --git a/spikeinterface/sorters/kilosortbase.py b/spikeinterface/sorters/kilosortbase.py index 399bfdd18f..70931f88ca 100644 --- a/spikeinterface/sorters/kilosortbase.py +++ b/spikeinterface/sorters/kilosortbase.py @@ -18,6 +18,7 @@ class KilosortBase: * _run_from_folder * _get_result_from_folder """ + gpu_capability = 'nvidia-required' @staticmethod def _generate_channel_map_file(recording, output_folder): diff --git a/spikeinterface/sorters/pykilosort/pykilosort.py b/spikeinterface/sorters/pykilosort/pykilosort.py index 6c37d3e5c1..082f705755 100644 --- a/spikeinterface/sorters/pykilosort/pykilosort.py +++ b/spikeinterface/sorters/pykilosort/pykilosort.py @@ -19,7 +19,7 @@ class PyKilosortSorter(BaseSorter): sorter_name = 'pykilosort' requires_locations = False - docker_requires_gpu = True + gpu_capability = 'nvidia-required' compatible_with_parallel = {'loky': True, 'multiprocessing': False, 'threading': False} _default_params = { diff --git a/spikeinterface/sorters/runsorter.py b/spikeinterface/sorters/runsorter.py index 488368d077..d147580b00 100644 --- a/spikeinterface/sorters/runsorter.py +++ b/spikeinterface/sorters/runsorter.py @@ -9,7 +9,7 @@ from ..version import version as si_version from spikeinterface.core.core_tools import check_json, recursive_path_modifier, is_dict_extractor from .sorterlist import sorter_dict -from .utils import SpikeSortingError +from .utils import SpikeSortingError, has_nvidia _common_param_doc = """ @@ -160,12 +160,14 @@ class ContainerClient: def __init__(self, mode, container_image, volumes, extra_kwargs): assert mode in ('docker', 'singularity') self.mode = mode + container_requires_gpu = extra_kwargs.get( + 'container_requires_gpu', None) if mode == 'docker': import docker client = docker.from_env() - if extra_kwargs.get('requires_gpu', False): - extra_kwargs.pop('requires_gpu') + if container_requires_gpu is not None: + extra_kwargs.pop('container_requires_gpu') extra_kwargs["device_requests"] = [ docker.types.DeviceRequest(count=-1, capabilities=[['gpu']])] @@ -198,7 +200,7 @@ def __init__(self, mode, container_image, volumes, extra_kwargs): options=['--bind', singularity_bind] # gpu options - if extra_kwargs.get('requires_gpu', False): + if container_requires_gpu: # only nvidia at the moment options += ['--nv'] @@ -306,8 +308,23 @@ def run_sorter_container(sorter_name, recording, mode, container_image, output_f install_si_from_source = False extra_kwargs = {} - if SorterClass.docker_requires_gpu: - extra_kwargs['requires_gpu'] = True + use_gpu = SorterClass.use_gpu(sorter_params) + gpu_capability = SorterClass.gpu_capability + + if use_gpu: + if gpu_capability == 'nvidia-required': + assert has_nvidia(), "The container requires a NVIDIA GPU capability, but it is not available" + extra_kwargs['container_requires_gpu'] = True + elif gpu_capability == 'nvidia-optional': + if has_nvidia(): + extra_kwargs['container_requires_gpu'] = True + else: + if verbose: + print(f"{SorterClass.sorter_name} supports GPU, but no GPU is available.\n" + f"Running the sorter without GPU") + else: + # TODO: make opencl machanism + raise NotImplementedError("Only nvidia support is available") container_client = ContainerClient(mode, container_image, volumes, extra_kwargs) if verbose: diff --git a/spikeinterface/sorters/utils/__init__.py b/spikeinterface/sorters/utils/__init__.py index 597bf057a2..1d45851b5b 100644 --- a/spikeinterface/sorters/utils/__init__.py +++ b/spikeinterface/sorters/utils/__init__.py @@ -1,2 +1,2 @@ from .shellscript import ShellScript -from .misc import (SpikeSortingError, get_git_commit) +from .misc import (SpikeSortingError, get_git_commit, has_nvidia) diff --git a/spikeinterface/sorters/utils/misc.py b/spikeinterface/sorters/utils/misc.py index ca1ae7a298..8b326eee61 100644 --- a/spikeinterface/sorters/utils/misc.py +++ b/spikeinterface/sorters/utils/misc.py @@ -20,3 +20,14 @@ def get_git_commit(git_folder, shorten=True): except: commit = None return commit + + +def has_nvidia(): + """ + Checks if the machine has nvidia capability. + """ + try: + check_output('nvidia-smi') + return True + except Exception: # this command not being found can raise quite a few different errors depending on the configuration + return False diff --git a/spikeinterface/sorters/yass/yass.py b/spikeinterface/sorters/yass/yass.py index b3806f1981..fb866f4f76 100644 --- a/spikeinterface/sorters/yass/yass.py +++ b/spikeinterface/sorters/yass/yass.py @@ -17,7 +17,7 @@ class YassSorter(BaseSorter): sorter_name = 'yass' requires_locations = False - docker_requires_gpu = True + gpu_capability = 'nvidia-required' # #################################################