Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions .github/workflows/medcat-v2_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,3 @@ jobs:
uses: pypa/gh-action-pypi-publish@release/v1
with:
packages-dir: medcat-v2/dist

- name: Create tag for medcat-scripts
run: |
git tag medcat-scritps/v${{ needs.build.outputs.version_only }}
git push origin medcat-scritps/v${{ needs.build.outputs.version_only }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

30 changes: 16 additions & 14 deletions medcat-v2/medcat/__main__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
import sys
from medcat.utils.download_scripts import main as __download_scripts


_DL_SCRIPTS_USAGE = (
"Usage: python -m medcat download-scripts [DEST] [log_level]")
# Dispatch table for ``python -m medcat <command>``: maps the sub-command
# name given on the command line to the callable that handles it. The
# handler is called with the remaining command-line arguments.
_COMMANDS = {
"download-scripts": __download_scripts
}


def _get_usage() -> str:
    """Build the usage text shown when no valid command is given.

    Returns:
        str: A header line followed by one ``--help`` invocation per
            command name registered in ``_COMMANDS``.
    """
    prefix, suffix = "python -m medcat ", " --help"
    # One example invocation per registered command, in registration order.
    invocations = (prefix + name + suffix for name in _COMMANDS)
    return "Available commands:\n" + "\n".join(invocations)


def main(*args: str):
if not args:
print(_DL_SCRIPTS_USAGE, file=sys.stderr)
sys.exit(1)
if len(args) >= 1 and args[0] == "download-scripts":
from medcat.utils.download_scripts import main
dest = args[1] if len(args) > 1 else "."
kwargs = {}
if len(args) > 2:
kwargs["log_level"] = args[2].upper()
main(dest, **kwargs)
else:
print(_DL_SCRIPTS_USAGE, file=sys.stderr)
if not args or args[0] not in _COMMANDS:
print(_get_usage(), file=sys.stderr)
sys.exit(1)
_COMMANDS[args[0]](*args[1:])


if __name__ == "__main__":
Expand Down
113 changes: 85 additions & 28 deletions medcat-v2/medcat/utils/download_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@
It will take the current setup (i.e. the medcat version) into account and
subsequently identify and download the medcat-scripts based on the most
recent applicable tag. So if you've got medcat==2.2.0, it might grab
medcat-scripts/v2.2.3 for instance.
medcat/v2.2.3 for instance.
"""
import importlib.metadata
import tempfile
import zipfile
from pathlib import Path
import requests
import logging
import argparse


logger = logging.getLogger(__name__)


EXPECTED_TAG_PREFIX = 'medcat/v'
GITHUB_REPO = "CogStack/cogstack-nlp"
SCRIPTS_PATH = "medcat-scripts/"
DOWNLOAD_URL_TEMPLATE = (
Expand All @@ -27,7 +29,9 @@ def _get_medcat_version() -> str:
"""Return the installed MedCAT version as 'major.minor'."""
version = importlib.metadata.version("medcat")
major, minor, *_ = version.split(".")
return f"{major}.{minor}"
minor_version = f"{major}.{minor}"
logger.debug("Using medcat minor version of %s", minor_version)
return minor_version


def _find_latest_scripts_tag(major_minor: str) -> str:
Expand All @@ -38,9 +42,10 @@ def _find_latest_scripts_tag(major_minor: str) -> str:
matching = [
t["name"]
for t in tags
if t["name"].startswith(f"medcat-scripts/v{major_minor}.")
or t["name"].startswith(f"v{major_minor}.")
if t["name"].startswith(f"{EXPECTED_TAG_PREFIX}{major_minor}.")
]
logger.debug("Found %d matching (out of a total of %d): %s",
len(matching), len(tags), matching)
if not matching:
raise RuntimeError(
f"No medcat-scripts tags found for MedCAT {major_minor}.x")
Expand All @@ -49,36 +54,42 @@ def _find_latest_scripts_tag(major_minor: str) -> str:
return matching[0]


def fetch_scripts(destination: str | Path = ".") -> Path:
"""Download the latest compatible medcat-scripts folder into.

Args:
destination (str | Path): The destination path. Defaults to ".".
def _determine_url(overwrite_url: str | None,
overwrite_tag: str | None) -> str:
if overwrite_url:
logger.info("Using the overwrite URL instead: %s", overwrite_url)
zip_url = overwrite_url
else:
version = _get_medcat_version()
if overwrite_tag:
tag = overwrite_tag
logger.info("Using overwritten tag '%s'", tag)
else:
tag = _find_latest_scripts_tag(version)

Returns:
Path: The path of the scripts.
"""
dest = Path(destination).expanduser().resolve()
dest.mkdir(parents=True, exist_ok=True)
logger.info("Fetching scripts for MedCAT %s → tag %s}",
version, tag)

version = _get_medcat_version()
tag = _find_latest_scripts_tag(version)
# Download the GitHub auto-generated zipball
zip_url = DOWNLOAD_URL_TEMPLATE.format(tag=tag)
return zip_url

logger.info("Fetching scripts for MedCAT %s → tag %s}",
version, tag)

# Download the GitHub auto-generated zipball
zip_url = DOWNLOAD_URL_TEMPLATE.format(tag=tag)
def _download_zip(zip_url: str, tmp: tempfile._TemporaryFileWrapper):
with requests.get(zip_url, stream=True, timeout=30) as r:
r.raise_for_status()
with tempfile.NamedTemporaryFile(delete=False) as tmp:
for chunk in r.iter_content(chunk_size=8192):
tmp.write(chunk)
zip_path = Path(tmp.name)
for chunk in r.iter_content(chunk_size=8192):
tmp.write(chunk)
tmp.flush()


def _extract_zip(dest: Path, zip_path: Path):
# Extract only medcat-scripts/ from the archive
wrote_files_num = 0
total_files = 0
with zipfile.ZipFile(zip_path) as zf:
for m in zf.namelist():
total_files += 1
if f"/{SCRIPTS_PATH}" not in m:
continue
# skip repo-hash prefix
Expand All @@ -88,14 +99,60 @@ def fetch_scripts(destination: str | Path = ".") -> Path:
else:
with open(target, "wb") as f:
f.write(zf.read(m))

wrote_files_num += 1

logger.debug("Wrote %d / %d files", wrote_files_num, total_files)
if not wrote_files_num:
logger.warning(
"Was unable to extract any files from '%s' folder in the zip. "
"The folder doesn't seem to exist in the provided archive.",
SCRIPTS_PATH)
logger.info("Scripts extracted to: %s", dest)


def fetch_scripts(destination: str | Path = ".",
                  overwrite_url: str | None = None,
                  overwrite_tag: str | None = None) -> Path:
    """Download the compatible medcat-scripts folder into a directory.

    The archive URL is resolved by ``_determine_url``: an explicit
    ``overwrite_url`` takes precedence, then ``overwrite_tag``, and
    otherwise the tag is looked up based on the installed MedCAT version.

    Args:
        destination (str | Path): The destination path. It is created
            (including parents) if it does not exist. Defaults to ".".
        overwrite_url (str | None): The overwrite URL. Defaults to None.
        overwrite_tag (str | None): The overwrite tag. Defaults to None.

    Returns:
        Path: The path of the scripts.
    """
    dest = Path(destination).expanduser().resolve()
    dest.mkdir(parents=True, exist_ok=True)

    zip_url = _determine_url(overwrite_url, overwrite_tag)
    # A NamedTemporaryFile cannot be reopened by name on Windows while it is
    # still open, and the extraction step opens the archive by path. So the
    # file is created with delete=False, closed once the download is
    # complete, and removed explicitly afterwards.
    tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    zip_path = Path(tmp.name)
    try:
        with tmp:
            _download_zip(zip_url, tmp)
        _extract_zip(dest, zip_path)
    finally:
        # Clean up the archive whether or not download/extraction succeeded.
        zip_path.unlink(missing_ok=True)
    return dest


def main(destination: str = ".",
log_level: int | str = logging.INFO):
def main(*in_args: str):
parser = argparse.ArgumentParser(
prog="python -m medcat download-scripts",
description="Download medcat-scripts"
)
parser.add_argument("destination", type=str, default=".", nargs='?',
help="The destination folder for the scripts")
parser.add_argument("--overwrite-url", type=str, default=None,
help="The URL to download and extract from. "
"This is expected to refer to a .zip file "
"that has a `medcat-scripts` folder.")
parser.add_argument("--overwrite-tag", '-t', type=str, default=None,
help="The tag to use from GitHub")
parser.add_argument("--log-level", type=str, default='INFO',
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="The log level for fetching")
args = parser.parse_args(in_args)
log_level = args.log_level
logger.setLevel(log_level)
if not logger.handlers:
logger.addHandler(logging.StreamHandler())
fetch_scripts(destination)
fetch_scripts(args.destination, args.overwrite_url,
args.overwrite_tag)
Loading