diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index e41d44f..15cda73 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -31,4 +31,8 @@ jobs: pip install .[test] - name: Pytest run: | - pytest -v + pytest -m "not library" -v + - name: Pytest library tests + if: ${{ matrix.os == 'ubuntu-latest' }} + run: | + pytest -m library -v diff --git a/doc/library.rst b/doc/library.rst index 091cde2..35f6f7c 100644 --- a/doc/library.rst +++ b/doc/library.rst @@ -1,27 +1,58 @@ Probeinterface public library ============================= -Probeinterface also handles a collection of probe descriptions on the -`GitHub platform `_ +Probeinterface also handles a collection of probe descriptions in the +`ProbeInterface library `_ -The python module has a simple function to download and cache locally by using `get_probe(...)` :: +The python module has a simple function to download and cache locally by using ``get_probe(...)``: + + +.. code-block:: python from probeinterface import get_probe - probe = get_probe(manufacturer='neuronexus', - probe_name='A1x32-Poly3-10mm-50-177') + probe = get_probe( + manufacturer='neuronexus', + probe_name='A1x32-Poly3-10mm-50-177' + ) + + +Once a probe is downloaded, it is cached locally for future use. + +There are several helper functions to explore the library: + +.. code-block:: python + + from probeinterface.library import ( + list_manufacturers, + list_probes_by_manufacturer, + list_all_probes + ) + + # List all manufacturers + manufacturers = list_manufacturers() + + # List all probes for a given manufacturer + probes = list_probes_by_manufacturer('neuronexus') + + # List all probes in the library + all_probes = list_all_probes() + + # Cache all probes locally + cache_full_library() + +Each function has an optional ``tag`` argument to specify a git tag/branch/commit to get a specific version of the library. -We expect to build rapidly commonly used probes in this public repository. -How to contribute ------------------ +How to contribute to the library +-------------------------------- -TODO: explain with more details +Each probe in the library is represented by a JSON file and an image. +To contribute a new probe to the library, follow these steps: - 1. Generate the JSON file with probeinterface (or directly - with another language) + 1. Generate the JSON file with probeinterface (or directly with another language) 2. Generate an image of the probe with the `plot_probe` function in probeinterface 3. Clone the `probeinterface_library repo `_ - 4. Put the JSON file and image into the correct folder or make a new folder (following the format of the repo) + 4. Put the JSON file and image into the correct folder: ``probeinterface_library///``` 5. Push to one of your branches with a git client 6. Make a pull request to the main repo diff --git a/pyproject.toml b/pyproject.toml index 46040db..1662d06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ dependencies = [ "numpy", "packaging", + "requests" ] [project.urls] @@ -57,6 +58,11 @@ docs = [ "pandas", ] +[tool.pytest.ini_options] +markers = [ + "library", +] + [tool.coverage.run] omit = [ "tests/*", diff --git a/src/probeinterface/__init__.py b/src/probeinterface/__init__.py index 4f8d746..3317c79 100644 --- a/src/probeinterface/__init__.py +++ b/src/probeinterface/__init__.py @@ -39,5 +39,13 @@ generate_multi_columns_probe, generate_multi_shank, ) -from .library import get_probe +from .library import ( + get_probe, + list_manufacturers, + list_probes_by_manufacturer, + list_all_probes, + get_tags_in_library, + cache_full_library, + clear_cache, +) from .wiring import get_available_pathways diff --git a/src/probeinterface/library.py b/src/probeinterface/library.py index 7666cd8..505ead0 100644 --- a/src/probeinterface/library.py +++ b/src/probeinterface/library.py @@ -11,23 +11,33 @@ from __future__ import annotations import os +import warnings from pathlib import Path from urllib.request import urlopen +import requests from typing import Optional from .io import read_probeinterface # OLD URL on gin # public_url = "https://web.gin.g-node.org/spikeinterface/probeinterface_library/raw/master/" - # Now on github since 2023/06/15 -public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/main/" +public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/" + # check this for windows and osx -cache_folder = Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library" +def get_cache_folder() -> Path: + """Get the cache folder for probeinterface library files. + Returns + ------- + cache_folder : Path + The path to the cache folder. + """ + return Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library" -def download_probeinterface_file(manufacturer: str, probe_name: str): + +def download_probeinterface_file(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> None: """Download the probeinterface file to the cache directory. Note that the file is itself a ProbeGroup but on the repo each file represents one probe. @@ -38,16 +48,24 @@ def download_probeinterface_file(manufacturer: str, probe_name: str): The probe manufacturer probe_name : str (see probeinterface_libary for options) The probe name + tag : str | None, default: None + Optional tag for the probe """ - os.makedirs(cache_folder / manufacturer, exist_ok=True) - localfile = cache_folder / manufacturer / (probe_name + ".json") - distantfile = public_url + f"{manufacturer}/{probe_name}/{probe_name}.json" - dist = urlopen(distantfile) - with open(localfile, "wb") as f: - f.write(dist.read()) + cache_folder = get_cache_folder() + if tag is not None: + assert tag in get_tags_in_library(), f"Tag {tag} not found in library" + else: + tag = "main" + os.makedirs(cache_folder / tag / manufacturer, exist_ok=True) + local_file = cache_folder / tag / manufacturer / (probe_name + ".json") + remote_file = public_url + tag + f"/{manufacturer}/{probe_name}/{probe_name}.json" + rem = urlopen(remote_file) + with open(local_file, "wb") as f: + f.write(rem.read()) -def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]: + +def get_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> Optional["Probe"]: """ Get Probe from local cache @@ -57,24 +75,40 @@ def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]: The probe manufacturer probe_name : str (see probeinterface_libary for options) The probe name + tag : str | None, default: None + Optional tag for the probe Returns ------- probe : Probe object, or None if no probeinterface JSON file is found """ + cache_folder = get_cache_folder() + if tag is not None: + cache_folder_tag = cache_folder / tag + if not cache_folder_tag.is_dir(): + return None + cache_folder = cache_folder_tag + else: + cache_folder_tag = cache_folder / "main" - localfile = cache_folder / manufacturer / (probe_name + ".json") - if not localfile.is_file(): + local_file = cache_folder_tag / manufacturer / (probe_name + ".json") + if not local_file.is_file(): return None else: - probegroup = read_probeinterface(localfile) + probegroup = read_probeinterface(local_file) probe = probegroup.probes[0] probe._probe_group = None return probe -def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> "Probe": +def get_probe( + manufacturer: str, + probe_name: str, + name: Optional[str] = None, + tag: Optional[str] = None, + force_download: bool = False, +) -> "Probe": """ Get probe from ProbeInterface library @@ -86,21 +120,173 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> The probe name name : str | None, default: None Optional name for the probe + tag : str | None, default: None + Optional tag for the probe + force_download : bool, default: False + If True, force re-download of the probe file. Returns ---------- probe : Probe object """ - - probe = get_from_cache(manufacturer, probe_name) + if not force_download: + probe = get_from_cache(manufacturer, probe_name, tag=tag) + else: + probe = None if probe is None: - download_probeinterface_file(manufacturer, probe_name) - probe = get_from_cache(manufacturer, probe_name) + download_probeinterface_file(manufacturer, probe_name, tag=tag) + probe = get_from_cache(manufacturer, probe_name, tag=tag) if probe.manufacturer == "": probe.manufacturer = manufacturer if name is not None: probe.name = name return probe + + +def cache_full_library(tag=None) -> None: # pragma: no cover + """ + Download all probes from the library to the cache directory. + """ + manufacturers = list_manufacturers(tag=tag) + + for manufacturer in manufacturers: + probes = list_probes_by_manufacturer(manufacturer, tag=tag) + for probe_name in probes: + try: + download_probeinterface_file(manufacturer, probe_name, tag=tag) + except Exception as e: + warnings.warn(f"Could not download {manufacturer}/{probe_name} (tag: {tag}): {e}") + + +def clear_cache(tag=None) -> None: # pragma: no cover + """ + Clear the cache folder for probeinterface library files. + + Parameters + ---------- + tag : str | None, default: None + Optional tag for the probe + """ + cache_folder = get_cache_folder() + if tag is not None: + cache_folder_tag = cache_folder / tag + if cache_folder_tag.is_dir(): + import shutil + + shutil.rmtree(cache_folder_tag) + else: + import shutil + + shutil.rmtree(cache_folder) + + +def list_manufacturers(tag=None) -> list[str]: + """ + Get the list of available manufacturers in the library + + Returns + ------- + manufacturers : list of str + List of available manufacturers + """ + if tag is not None: + assert ( + tag in get_tags_in_library() + ), f"Tag {tag} not found in library. Available tags are {get_tags_in_library()}." + return list_github_folders("SpikeInterface", "probeinterface_library", ref=tag) + + +def list_probes_by_manufacturer(manufacturer: str, tag=None) -> list[str]: + """ + Get the list of available probes for a given manufacturer + + Parameters + ---------- + manufacturer : str + The probe manufacturer + + Returns + ------- + probes : list of str + List of available probes for the given manufacturer + """ + if tag is not None: + assert ( + tag in get_tags_in_library() + ), f"Tag {tag} not found in library. Available tags are {get_tags_in_library()}." + assert manufacturer in list_manufacturers( + tag=tag + ), f"Manufacturer {manufacturer} not found in library. Available manufacturers are {list_manufacturers(tag=tag)}." + return list_github_folders("SpikeInterface", "probeinterface_library", path=manufacturer, ref=tag) + + +def list_all_probes(tag=None) -> dict[str, list[str]]: + """ + Get the list of all available probes in the library + + Returns + ------- + all_probes : dict + Dictionary with manufacturers as keys and list of probes as values + """ + all_probes = {} + manufacturers = list_manufacturers(tag=tag) + for manufacturer in manufacturers: + probes = list_probes_by_manufacturer(manufacturer, tag=tag) + all_probes[manufacturer] = probes + return all_probes + + +def get_tags_in_library() -> list[str]: + """ + Get the list of available tags in the library + + Returns + ------- + tags : list of str + List of available tags + """ + tags = get_all_tags("SpikeInterface", "probeinterface_library") + return tags + + +### UTILS +def get_all_tags(owner: str, repo: str, token: str = None): + """ + Get all tags for a repo. + Returns a list of tag names, or an empty list if no tags exist. + """ + url = f"https://api.github.com/repos/{owner}/{repo}/tags" + headers = {} + if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"): + token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN") + headers["Authorization"] = f"token {token}" + resp = requests.get(url, headers=headers) + if resp.status_code != 200: + raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}") + tags = resp.json() + return [tag["name"] for tag in tags] + + +def list_github_folders(owner: str, repo: str, path: str = "", ref: str = None, token: str = None): + """ + Return a list of directory names in the given repo at the specified path. + You can pass a branch, tag, or commit SHA via `ref`. + If token is provided, use it for authenticated requests (higher rate limits). + """ + url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}" + params = {} + if ref: + params["ref"] = ref + headers = {} + if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"): + token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN") + headers["Authorization"] = f"token {token}" + resp = requests.get(url, headers=headers, params=params) + if resp.status_code != 200: + raise RuntimeError(f"GitHub API returned status {resp.status_code}: {resp.text}") + items = resp.json() + return [item["name"] for item in items if item.get("type") == "dir" and item["name"][0] != "."] diff --git a/tests/test_library.py b/tests/test_library.py index 8d4059d..2134ff2 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -1,19 +1,54 @@ +import os +import pytest + from probeinterface import Probe -from probeinterface.library import download_probeinterface_file, get_from_cache, get_probe +from probeinterface.library import ( + download_probeinterface_file, + get_from_cache, + get_probe, + get_tags_in_library, + list_manufacturers, + list_probes_by_manufacturer, + list_all_probes, + get_cache_folder, + cache_full_library, + clear_cache, +) -from pathlib import Path -import numpy as np +manufacturer = "neuronexus" +probe_name = "A1x32-Poly3-10mm-50-177" -import pytest +def _remove_from_cache(manufacturer: str, probe_name: str, tag=None) -> None: + """ + Remove Probe from local cache -manufacturer = "neuronexus" -probe_name = "A1x32-Poly3-10mm-50-177" + Parameters + ---------- + manufacturer : "cambridgeneurotech" | "neuronexus" | "plexon" | "imec" | "sinaps" + The probe manufacturer + probe_name : str (see probeinterface_libary for options) + The probe name + tag : str | None, default: None + Optional tag for the probe + """ + cache_folder = get_cache_folder() + if tag is not None: + cache_folder_tag = cache_folder / tag + if not cache_folder_tag.is_dir(): + return None + cache_folder = cache_folder_tag + else: + cache_folder_tag = cache_folder / "main" + + local_file = cache_folder_tag / manufacturer / (probe_name + ".json") + if local_file.is_file(): + os.remove(local_file) def test_download_probeinterface_file(): - download_probeinterface_file(manufacturer, probe_name) + download_probeinterface_file(manufacturer, probe_name, tag=None) def test_get_from_cache(): @@ -21,6 +56,14 @@ def test_get_from_cache(): probe = get_from_cache(manufacturer, probe_name) assert isinstance(probe, Probe) + tag = get_tags_in_library()[0] + probe = get_from_cache(manufacturer, probe_name, tag=tag) + assert probe is None # because we did not download with this tag + download_probeinterface_file(manufacturer, probe_name, tag=tag) + probe = get_from_cache(manufacturer, probe_name, tag=tag) + _remove_from_cache(manufacturer, probe_name, tag=tag) + assert isinstance(probe, Probe) + probe = get_from_cache("yep", "yop") assert probe is None @@ -31,7 +74,54 @@ def test_get_probe(): assert probe.get_contact_count() == 32 +def test_available_tags(): + tags = get_tags_in_library() + if len(tags) > 0: + for tag in tags: + assert isinstance(tag, str) + assert len(tag) > 0 + + +@pytest.mark.library +def test_list_manufacturers(): + manufacturers = list_manufacturers() + assert isinstance(manufacturers, list) + assert "neuronexus" in manufacturers + assert "imec" in manufacturers + + +@pytest.mark.library +def test_list_probes(): + manufacturers = list_all_probes() + for manufacturer in manufacturers: + probes = list_probes_by_manufacturer(manufacturer) + assert isinstance(probes, list) + assert len(probes) > 0 + + +@pytest.mark.skip(reason="long test that downloads the full library") +def test_cache_full_library(): + tag = get_tags_in_library()[0] if len(get_tags_in_library()) > 0 else None + print(tag) + cache_full_library(tag=tag) + all_probes = list_all_probes(tag=tag) + # spot check that a known probe is in the cache + for manufacturer, probes in all_probes.items(): + for probe_name in probes: + probe = get_from_cache(manufacturer, probe_name, tag=tag) + assert isinstance(probe, Probe) + + clear_cache(tag=tag) + for manufacturer, probes in all_probes.items(): + for probe_name in probes: + probe = get_from_cache(manufacturer, probe_name, tag=tag) + assert probe is None + + if __name__ == "__main__": test_download_probeinterface_file() test_get_from_cache() test_get_probe() + test_list_manufacturers() + test_list_probes() + test_cache_full_library() diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 2867804..f19c0d7 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -39,7 +39,8 @@ def test_plot_probegroup(): plot_probegroup(probegroup, same_axes=False) # remove when plot_probe_group has been removed - plot_probe_group(probegroup) + with pytest.warns(DeprecationWarning): + plot_probe_group(probegroup) # 3d probegroup_3d = ProbeGroup()