diff --git a/.github/workflows/medcat-v2_release.yml b/.github/workflows/medcat-v2_release.yml
index e3cf5ca9..6f1d0dcc 100644
--- a/.github/workflows/medcat-v2_release.yml
+++ b/.github/workflows/medcat-v2_release.yml
@@ -200,11 +200,3 @@ jobs:
         uses: pypa/gh-action-pypi-publish@release/v1
         with:
           packages-dir: medcat-v2/dist
-
-      - name: Create tag for medcat-scripts
-        run: |
-          git tag medcat-scritps/v${{ needs.build.outputs.version_only }}
-          git push origin medcat-scritps/v${{ needs.build.outputs.version_only }}
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
diff --git a/medcat-v2/medcat/__main__.py b/medcat-v2/medcat/__main__.py
index f032b288..c0fc985b 100644
--- a/medcat-v2/medcat/__main__.py
+++ b/medcat-v2/medcat/__main__.py
@@ -1,24 +1,26 @@
 import sys
+from medcat.utils.download_scripts import main as _download_scripts
 
 
-_DL_SCRIPTS_USAGE = (
-    "Usage: python -m medcat download-scripts [DEST] [log_level]")
+_COMMANDS = {
+    "download-scripts": _download_scripts
+}
+
+
+def _get_usage() -> str:
+    header = "Available commands:\n"
+    base = "python -m medcat "
+    args = " --help"
+    commands = [base + cmd_name + args
+                for cmd_name in _COMMANDS]
+    return header + "\n".join(commands)
 
 
 def main(*args: str):
-    if not args:
-        print(_DL_SCRIPTS_USAGE, file=sys.stderr)
-        sys.exit(1)
-    if len(args) >= 1 and args[0] == "download-scripts":
-        from medcat.utils.download_scripts import main
-        dest = args[1] if len(args) > 1 else "."
-        kwargs = {}
-        if len(args) > 2:
-            kwargs["log_level"] = args[2].upper()
-        main(dest, **kwargs)
-    else:
-        print(_DL_SCRIPTS_USAGE, file=sys.stderr)
+    if not args or args[0] not in _COMMANDS:
+        print(_get_usage(), file=sys.stderr)
         sys.exit(1)
+    _COMMANDS[args[0]](*args[1:])
 
 
 if __name__ == "__main__":
diff --git a/medcat-v2/medcat/utils/download_scripts.py b/medcat-v2/medcat/utils/download_scripts.py
index bd42c9e5..6b591970 100644
--- a/medcat-v2/medcat/utils/download_scripts.py
+++ b/medcat-v2/medcat/utils/download_scripts.py
@@ -3,7 +3,7 @@
 It will link the current setup (i.e medcat version) into account and
 subsequently identify and download the medcat-scripts based on the most
 recent applicable tag. So if you've got medcat==2.2.0, it might grab
-medcat-scripts/v2.2.3 for instance.
+medcat/v2.2.3 for instance.
 """
 import importlib.metadata
 import tempfile
@@ -11,11 +11,13 @@
 from pathlib import Path
 import requests
 import logging
+import argparse
 
 
 logger = logging.getLogger(__name__)
 
 
+EXPECTED_TAG_PREFIX = 'medcat/v'
 GITHUB_REPO = "CogStack/cogstack-nlp"
 SCRIPTS_PATH = "medcat-scripts/"
 DOWNLOAD_URL_TEMPLATE = (
@@ -27,7 +29,9 @@ def _get_medcat_version() -> str:
     """Return the installed MedCAT version as 'major.minor'."""
     version = importlib.metadata.version("medcat")
     major, minor, *_ = version.split(".")
-    return f"{major}.{minor}"
+    minor_version = f"{major}.{minor}"
+    logger.debug("Using medcat minor version of %s", minor_version)
+    return minor_version
 
 
 def _find_latest_scripts_tag(major_minor: str) -> str:
@@ -38,9 +42,10 @@ def _find_latest_scripts_tag(major_minor: str) -> str:
     matching = [
         t["name"]
         for t in tags
-        if t["name"].startswith(f"medcat-scripts/v{major_minor}.")
-        or t["name"].startswith(f"v{major_minor}.")
+        if t["name"].startswith(f"{EXPECTED_TAG_PREFIX}{major_minor}.")
     ]
+    logger.debug("Found %d matching (out of a total of %d): %s",
+                 len(matching), len(tags), matching)
     if not matching:
         raise RuntimeError(
             f"No medcat-scripts tags found for MedCAT {major_minor}.x")
@@ -49,36 +54,42 @@ def _find_latest_scripts_tag(major_minor: str) -> str:
     return matching[0]
 
 
-def fetch_scripts(destination: str | Path = ".") -> Path:
-    """Download the latest compatible medcat-scripts folder into.
-
-    Args:
-        destination (str | Path): The destination path. Defaults to ".".
+def _determine_url(overwrite_url: str | None,
+                   overwrite_tag: str | None) -> str:
+    if overwrite_url:
+        logger.info("Using the overwrite URL instead: %s", overwrite_url)
+        zip_url = overwrite_url
+    else:
+        version = _get_medcat_version()
+        if overwrite_tag:
+            tag = overwrite_tag
+            logger.info("Using overwritten tag '%s'", tag)
+        else:
+            tag = _find_latest_scripts_tag(version)
 
-    Returns:
-        Path: The path of the scripts.
-    """
-    dest = Path(destination).expanduser().resolve()
-    dest.mkdir(parents=True, exist_ok=True)
+        logger.info("Fetching scripts for MedCAT %s → tag %s",
+                    version, tag)
 
-    version = _get_medcat_version()
-    tag = _find_latest_scripts_tag(version)
+        # Download the GitHub auto-generated zipball
+        zip_url = DOWNLOAD_URL_TEMPLATE.format(tag=tag)
+    return zip_url
 
-    logger.info("Fetching scripts for MedCAT %s → tag %s}",
-                version, tag)
 
-    # Download the GitHub auto-generated zipball
-    zip_url = DOWNLOAD_URL_TEMPLATE.format(tag=tag)
+def _download_zip(zip_url: str, tmp: tempfile._TemporaryFileWrapper):
     with requests.get(zip_url, stream=True, timeout=30) as r:
         r.raise_for_status()
-        with tempfile.NamedTemporaryFile(delete=False) as tmp:
-            for chunk in r.iter_content(chunk_size=8192):
-                tmp.write(chunk)
-            zip_path = Path(tmp.name)
+        for chunk in r.iter_content(chunk_size=8192):
+            tmp.write(chunk)
+    tmp.flush()
+
 
+def _extract_zip(dest: Path, zip_path: Path):
     # Extract only medcat-scripts/ from the archive
+    wrote_files_num = 0
+    total_files = 0
     with zipfile.ZipFile(zip_path) as zf:
         for m in zf.namelist():
+            total_files += 1
             if f"/{SCRIPTS_PATH}" not in m:
                 continue
             # skip repo-hash prefix
@@ -88,14 +99,67 @@ def fetch_scripts(destination: str | Path = ".") -> Path:
             else:
                 with open(target, "wb") as f:
                     f.write(zf.read(m))
-
+                wrote_files_num += 1
+
+    logger.debug("Wrote %d / %d files", wrote_files_num, total_files)
+    if not wrote_files_num:
+        logger.warning(
+            "Was unable to extract any files from '%s' folder in the zip. "
+            "The folder doesn't seem to exist in the provided archive.",
+            SCRIPTS_PATH)
     logger.info("Scripts extracted to: %s", dest)
+
+
+def fetch_scripts(destination: str | Path = ".",
+                  overwrite_url: str | None = None,
+                  overwrite_tag: str | None = None) -> Path:
+    """Download the latest compatible medcat-scripts folder into destination.
+
+    Args:
+        destination (str | Path): The destination path. Defaults to ".".
+        overwrite_url (str | None): The overwrite URL. Defaults to None.
+        overwrite_tag (str | None): The overwrite tag. Defaults to None.
+
+    Returns:
+        Path: The path of the scripts.
+    """
+    dest = Path(destination).expanduser().resolve()
+    dest.mkdir(parents=True, exist_ok=True)
+
+    zip_url = _determine_url(overwrite_url, overwrite_tag)
+    # The archive is closed before it is re-opened by name: an open
+    # NamedTemporaryFile cannot be opened a second time on Windows.
+    tmp = tempfile.NamedTemporaryFile(delete=False)
+    zip_path = Path(tmp.name)
+    try:
+        with tmp:
+            _download_zip(zip_url, tmp)
+        _extract_zip(dest, zip_path)
+    finally:
+        zip_path.unlink(missing_ok=True)
     return dest
 
 
-def main(destination: str = ".",
-         log_level: int | str = logging.INFO):
+def main(*in_args: str):
+    parser = argparse.ArgumentParser(
+        prog="python -m medcat download-scripts",
+        description="Download medcat-scripts"
+    )
+    parser.add_argument("destination", type=str, default=".", nargs='?',
+                        help="The destination folder for the scripts")
+    parser.add_argument("--overwrite-url", type=str, default=None,
+                        help="The URL to download and extract from. "
+                        "This is expected to refer to a .zip file "
+                        "that has a `medcat-scripts` folder.")
+    parser.add_argument("--overwrite-tag", '-t', type=str, default=None,
+                        help="The tag to use from GitHub")
+    parser.add_argument("--log-level", type=str, default='INFO',
+                        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
+                        help="The log level for fetching")
+    args = parser.parse_args(in_args)
+    log_level = args.log_level
     logger.setLevel(log_level)
     if not logger.handlers:
         logger.addHandler(logging.StreamHandler())
-    fetch_scripts(destination)
+    fetch_scripts(args.destination, args.overwrite_url,
+                  args.overwrite_tag)