From 9a6a80a8d9d7a21ee8252de159ee423d66300220 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 8 Apr 2022 23:49:15 +0800 Subject: [PATCH 01/20] draft download Signed-off-by: Yiheng Wang --- docs/source/bundle.rst | 1 + monai/bundle/__init__.py | 2 +- monai/bundle/__main__.py | 2 +- monai/bundle/scripts.py | 73 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 75 insertions(+), 3 deletions(-) diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 297409cd7e..023effaad1 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -36,6 +36,7 @@ Model Bundle `Scripts` --------- .. autofunction:: ckpt_export +.. autofunction:: download .. autofunction:: run .. autofunction:: verify_metadata .. autofunction:: verify_net_in_out diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index d6a452b5a4..24995cb971 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -12,5 +12,5 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser from .reference_resolver import ReferenceResolver -from .scripts import ckpt_export, run, verify_metadata, verify_net_in_out +from .scripts import ckpt_export, download, run, verify_metadata, verify_net_in_out from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index d77b396e79..3e3534ef74 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -10,7 +10,7 @@ # limitations under the License. -from monai.bundle.scripts import ckpt_export, run, verify_metadata, verify_net_in_out +from monai.bundle.scripts import ckpt_export, download, run, verify_metadata, verify_net_in_out if __name__ == "__main__": from monai.utils import optional_import diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 64172c4541..224c684920 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -11,6 +11,7 @@ import ast import json +import os import pprint import re from logging.config import fileConfig @@ -19,7 +20,7 @@ import torch from torch.cuda import is_available -from monai.apps.utils import download_url, get_logger +from monai.apps.utils import download_url, extractall, get_logger from monai.bundle.config_parser import ConfigParser from monai.config import IgniteInfo, PathLike from monai.data import save_net_with_metadata @@ -29,6 +30,7 @@ validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") +requests_get, has_requests = optional_import("requests", name="get") logger = get_logger(module_name=__name__) @@ -116,6 +118,75 @@ def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int return tuple(ret) +def _get_git_release_url( + repo_owner: str, + repo_name: str, + tag_name: str, + filename: Optional[str] = None, +): + if filename is not None: + return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}" + else: + raise NotImplementedError("download the whole package is not implemented so far.") + + +def download( + repo: str, + package: str, + filename: Optional[str] = None, + source: str = "github", + download_path: str = "download", + hash_val: Optional[str] = None, + hash_type: str = "md5", + extract: bool = False, + has_base: bool = True, + version: int = 1, + # args_file +): + """ + download the bundle package or a file that belongs to the package from the specified source. + This function refers to: + https://pytorch.org/docs/stable/_modules/torch/hub.html + + Args: + repo: the bundle package. The format depends on the source. If the source is `github`, + it should be in the form of `repo_owner/repo_name`. If the source is `ngc`, it should + be in the form of `org/team`. + package: the bundle package name. If the source is `github`, it should be the same as the + release tag. + filename: the filename of the bundle package that needs to be downloaded. It is an optional + argument and if not specified, the whole bundle package will be downloaded. + source: the place that saved the bundle package. So far, only `github` and `ngc` are supported. + For the `github` source, the bundle package should be within the releases. + download_path: target filepath to save the downloaded file (including the filename). + If undefined, `os.path.basename(url)` will be used. + hash_val: expected hash value to validate the downloaded file. + if None, skip hash validation. + hash_type: 'md5' or 'sha1'. + extract: whether to extract the downloaded file. + output_dir: target directory to save extracted files. + has_base: whether the extracted files have a base folder. This flag is used when checking if the existing + folder is a result of `extractall`, if it is, the extraction is skipped. For example, if A.zip is unzipped + to folder structure `A/*.png`, this flag should be True; if B.zip is unzipped to `*.png`, this flag should + be False. + version: this argument only works on ngc souce, and it represents the version of the model. + + """ + if source == "github": + if len(repo.split("/")) != 2: + raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name`.") + repo_owner, repo_name = repo.split("/") + url = _get_git_release_url(repo_owner, repo_name, tag_name=package, filename=filename) + filepath = os.path.join(download_path, filename) + elif source == "ngc": + # to be modified + url = f"https://api.ngc.nvidia.com/v2/models/{repo}/{package}/versions/{version}/zip" + filepath = os.path.join(download_path, f"{package}.zip") + download_url(url=url, filepath=filepath, hash_val=hash_val, hash_type=hash_type) + if extract is True: + extractall(filepath=filepath, output_dir=download_path, has_base=has_base) + + def run( runner_id: Optional[str] = None, meta_file: Optional[Union[str, Sequence[str]]] = None, From 3b6d2cd6299616ea1694a45ab3327faa74a92c33 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 12 Apr 2022 23:07:16 +0800 Subject: [PATCH 02/20] update bundle download Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 141 ++++++++++++++++++++++++++++------------ 1 file changed, 99 insertions(+), 42 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 224c684920..900754184f 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -11,10 +11,10 @@ import ast import json -import os import pprint import re from logging.config import fileConfig +from pathlib import Path from typing import Dict, Optional, Sequence, Tuple, Union import torch @@ -122,26 +122,63 @@ def _get_git_release_url( repo_owner: str, repo_name: str, tag_name: str, + filename: str, +): + return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}" + + +def _get_git_release_assets( + repo_owner: str, + repo_name: str, + tag_name: str, +): + url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{tag_name}" + resp = requests_get(url) + resp.raise_for_status() + assets_list = json.loads(resp.text)["assets"] + assets_info = {} + for asset in assets_list: + assets_info[asset["name"]] = asset["browser_download_url"] + return assets_info + + +def _download_from_github( + repo: str, + tag_name: str, + bundle_dir: Path, filename: Optional[str] = None, + progress: bool = True, + extract: bool = False, ): - if filename is not None: - return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}" + if len(repo.split("/")) != 2: + raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name`.") + repo_owner, repo_name = repo.split("/") + if filename is None: + # download the whole bundle package if filename is not provided + assets_info = _get_git_release_assets(repo_owner, repo_name, tag_name=tag_name) + for name, url in assets_info.items(): + download_url(url=url, filepath=bundle_dir / f"{name}", hash_val=None, progress=progress) + if extract is True: + logger.info("When download a whole bundle package, extract is not supported, skip extracting.") + logger.info(f"All files within the bundle package {tag_name} are downloaded.") else: - raise NotImplementedError("download the whole package is not implemented so far.") + # download a single file + url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename) + filepath = bundle_dir / f"{filename}" + download_url(url=url, filepath=filepath, hash_val=None, progress=progress) + if extract is True: + extractall(filepath=filepath, output_dir=bundle_dir, has_base=False) def download( - repo: str, - package: str, + repo: Optional[str] = None, + package: Optional[str] = None, + bundle_dir: Optional[PathLike] = None, filename: Optional[str] = None, source: str = "github", - download_path: str = "download", - hash_val: Optional[str] = None, - hash_type: str = "md5", + progress: bool = True, extract: bool = False, - has_base: bool = True, - version: int = 1, - # args_file + args_file: Optional[str] = None, ): """ download the bundle package or a file that belongs to the package from the specified source. @@ -149,42 +186,62 @@ def download( https://pytorch.org/docs/stable/_modules/torch/hub.html Args: - repo: the bundle package. The format depends on the source. If the source is `github`, - it should be in the form of `repo_owner/repo_name`. If the source is `ngc`, it should - be in the form of `org/team`. - package: the bundle package name. If the source is `github`, it should be the same as the - release tag. + repo: the repo name. The format depends on the source. If `None`, must be provided in `args_file`. + If the source is `github`, it should be in the form of `repo_owner/repo_name`. + For example: `Project-MONAI/MONAI`. + package: the bundle package name. If `None`, must be provided in `args_file`. + If the source is `github`, it should be the same as the release tag. + bundle_dir: target directory to store the download data. + Default is `bundle` subfolder under `torch.hub get_dir()`. + If undefined, `os.path.basename(url)` will be used. filename: the filename of the bundle package that needs to be downloaded. It is an optional argument and if not specified, the whole bundle package will be downloaded. source: the place that saved the bundle package. So far, only `github` and `ngc` are supported. For the `github` source, the bundle package should be within the releases. - download_path: target filepath to save the downloaded file (including the filename). - If undefined, `os.path.basename(url)` will be used. - hash_val: expected hash value to validate the downloaded file. - if None, skip hash validation. - hash_type: 'md5' or 'sha1'. - extract: whether to extract the downloaded file. - output_dir: target directory to save extracted files. - has_base: whether the extracted files have a base folder. This flag is used when checking if the existing - folder is a result of `extractall`, if it is, the extraction is skipped. For example, if A.zip is unzipped - to folder structure `A/*.png`, this flag should be True; if B.zip is unzipped to `*.png`, this flag should - be False. - version: this argument only works on ngc souce, and it represents the version of the model. + progress: whether to display a progress bar. + extract: whether to extract the downloaded file. This argument only works when download a single file. + args_file: a JSON or YAML file to provide default values for all the args in this function. + so that the command line inputs can be simplified. """ - if source == "github": - if len(repo.split("/")) != 2: - raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name`.") - repo_owner, repo_name = repo.split("/") - url = _get_git_release_url(repo_owner, repo_name, tag_name=package, filename=filename) - filepath = os.path.join(download_path, filename) - elif source == "ngc": - # to be modified - url = f"https://api.ngc.nvidia.com/v2/models/{repo}/{package}/versions/{version}/zip" - filepath = os.path.join(download_path, f"{package}.zip") - download_url(url=url, filepath=filepath, hash_val=hash_val, hash_type=hash_type) - if extract is True: - extractall(filepath=filepath, output_dir=download_path, has_base=has_base) + _args = _update_args( + args=args_file, + repo=repo, + package=package, + bundle_dir=bundle_dir, + filename=filename, + source=source, + progress=progress, + extract=extract, + ) + if "repo" not in _args: + raise ValueError(f"`repo` is required for 'monai.bundle download'.\n{download.__doc__}") + if "package" not in _args: + raise ValueError(f"`package` is required for 'monai.bundle download'.\n{download.__doc__}") + _log_input_summary(tag="download", args=_args) + repo_, package_, bundle_dir_, filename_, source_, progress_, extract_ = _pop_args( + _args, "repo", "package", bundle_dir=None, filename=None, source="github", progress=True, extract=False + ) + + if not bundle_dir_: + get_dir, has_home = optional_import("torch.hub", name="get_dir") + if has_home: + bundle_dir_ = Path(get_dir()) / "bundle" + else: + raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") + bundle_dir_ = Path(bundle_dir_) + + if source_ == "github": + _download_from_github( + repo=repo_, + tag_name=package_, + bundle_dir=bundle_dir_, + filename=filename_, + progress=progress_, + extract=extract_, + ) + else: + raise NotImplementedError("So far, only `github` source is supported.") def run( From 899dbd65cd6edcfe5e4843e52d1a49035a2cf2c9 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 14 Apr 2022 19:27:40 +0800 Subject: [PATCH 03/20] add url and load Signed-off-by: Yiheng Wang --- monai/bundle/__init__.py | 2 +- monai/bundle/scripts.py | 137 +++++++++++++++++++++++++++++++-------- 2 files changed, 111 insertions(+), 28 deletions(-) diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 24995cb971..f30ee9c40c 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -12,5 +12,5 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser from .reference_resolver import ReferenceResolver -from .scripts import ckpt_export, download, run, verify_metadata, verify_net_in_out +from .scripts import ckpt_export, download, load, run, verify_metadata, verify_net_in_out from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 900754184f..980bae2f19 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -11,6 +11,7 @@ import ast import json +import os import pprint import re from logging.config import fileConfig @@ -20,6 +21,8 @@ import torch from torch.cuda import is_available +import warnings + from monai.apps.utils import download_url, extractall, get_logger from monai.bundle.config_parser import ConfigParser from monai.config import IgniteInfo, PathLike @@ -145,7 +148,7 @@ def _get_git_release_assets( def _download_from_github( repo: str, tag_name: str, - bundle_dir: Path, + download_path: Path, filename: Optional[str] = None, progress: bool = True, extract: bool = False, @@ -157,47 +160,62 @@ def _download_from_github( # download the whole bundle package if filename is not provided assets_info = _get_git_release_assets(repo_owner, repo_name, tag_name=tag_name) for name, url in assets_info.items(): - download_url(url=url, filepath=bundle_dir / f"{name}", hash_val=None, progress=progress) + download_url(url=url, filepath=download_path / f"{name}", hash_val=None, progress=progress) if extract is True: logger.info("When download a whole bundle package, extract is not supported, skip extracting.") logger.info(f"All files within the bundle package {tag_name} are downloaded.") else: # download a single file url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename) - filepath = bundle_dir / f"{filename}" + filepath = download_path / f"{filename}" download_url(url=url, filepath=filepath, hash_val=None, progress=progress) if extract is True: - extractall(filepath=filepath, output_dir=bundle_dir, has_base=False) + extractall(filepath=filepath, output_dir=download_path, has_base=True) def download( - repo: Optional[str] = None, package: Optional[str] = None, bundle_dir: Optional[PathLike] = None, filename: Optional[str] = None, source: str = "github", + repo: Optional[str] = None, + url: Optional[str] = None, progress: bool = True, extract: bool = False, args_file: Optional[str] = None, ): """ - download the bundle package or a file that belongs to the package from the specified source. + download the bundle package or a file that belongs to the package from the specified source or url. This function refers to: https://pytorch.org/docs/stable/_modules/torch/hub.html + Typical usage examples: + + .. code-block:: bash + + # Execute this module as a CLI entry, and download the whole bundle package: + python -m monai.bundle download --package --source "github" --repo + + # Execute this module as a CLI entry, and download a single file: + python -m monai.bundle download --package --filename --repo + + # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. + # Other args still can override the default args at runtime: + python -m monai.bundle download --args_file "/workspace/data/args.json" --filename + Args: - repo: the repo name. The format depends on the source. If `None`, must be provided in `args_file`. - If the source is `github`, it should be in the form of `repo_owner/repo_name`. + package: the bundle package name. If `None` and `url` is `None`, it must be provided in `args_file`. + If `source` is `github`, it should be the same as the release tag. + bundle_dir: target directory to store the download data. Default is `bundle` subfolder under`torch.hub get_dir()`. + filename: the filename that needs to be downloaded. + If `source` is `github` and filename is `None`, the whole bundle package will be downloaded. + source: the place that saved the bundle package. + If `source` is `github`, the bundle package should be within the releases. + repo: the repo name. If `None` and `url` is `None`, it must be provided in `args_file`. + If `source` is `github`, it should be in the form of `repo_owner/repo_name`. For example: `Project-MONAI/MONAI`. - package: the bundle package name. If `None`, must be provided in `args_file`. - If the source is `github`, it should be the same as the release tag. - bundle_dir: target directory to store the download data. - Default is `bundle` subfolder under `torch.hub get_dir()`. - If undefined, `os.path.basename(url)` will be used. - filename: the filename of the bundle package that needs to be downloaded. It is an optional - argument and if not specified, the whole bundle package will be downloaded. - source: the place that saved the bundle package. So far, only `github` and `ngc` are supported. - For the `github` source, the bundle package should be within the releases. + url: the url to download the data. It is an optional argument and if not `None`, + data will be downloaded directly and `source` will not be checked. progress: whether to display a progress bar. extract: whether to extract the downloaded file. This argument only works when download a single file. args_file: a JSON or YAML file to provide default values for all the args in this function. @@ -206,21 +224,22 @@ def download( """ _args = _update_args( args=args_file, - repo=repo, package=package, bundle_dir=bundle_dir, filename=filename, source=source, + repo=repo, + url=url, progress=progress, extract=extract, ) - if "repo" not in _args: - raise ValueError(f"`repo` is required for 'monai.bundle download'.\n{download.__doc__}") - if "package" not in _args: - raise ValueError(f"`package` is required for 'monai.bundle download'.\n{download.__doc__}") + if "package" not in _args and url is None: + raise ValueError(f"To download from source: {source}, `package` must be provided.") + if "repo" not in _args and url is None: + raise ValueError(f"To download from source: {source}, `repo` must be provided.") _log_input_summary(tag="download", args=_args) - repo_, package_, bundle_dir_, filename_, source_, progress_, extract_ = _pop_args( - _args, "repo", "package", bundle_dir=None, filename=None, source="github", progress=True, extract=False + package_, bundle_dir_, filename_, source_, repo_, url_, progress_, extract_ = _pop_args( + _args, package=None, bundle_dir=None, filename=None, source="github", repo=None, url=None, progress=True, extract=False ) if not bundle_dir_: @@ -231,17 +250,81 @@ def download( raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") bundle_dir_ = Path(bundle_dir_) - if source_ == "github": + if url_ is not None: + if filename_ is not None: + filepath = bundle_dir_ / f"{filename_}" + download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_) + if extract_ is True: + extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True) + else: + download_url(url=url_, filepath=bundle_dir_, hash_val=None, progress=progress_) + if extract_ is True: + logger.info("When download a whole bundle package, extract is not supported, skip extracting.") + elif source_ == "github": _download_from_github( repo=repo_, tag_name=package_, - bundle_dir=bundle_dir_, + download_path=bundle_dir_, filename=filename_, progress=progress_, extract=extract_, ) else: - raise NotImplementedError("So far, only `github` source is supported.") + raise NotImplementedError("So far, only support to download from url or `github` source.") + + +def load( + weights_name: str = "model.pt", + is_ts_model: bool = False, + package: Optional[str] = None, + bundle_dir: PathLike = ".", + source: str = "github", + repo: Optional[str] = None, + url: Optional[str] = None, + progress: bool = True, + extract: bool = False, + map_location=None, +): + """ + Download (if necessary) and load model weights. + + Args: + weights_name: the name of the weights file that will be loaded. + is_ts_model, a flag to specify if the weights file is a TorchScript module. + package: the bundle package name. + If the weights need to be downloaded first and `url` is `None`, it must be provided. + bundle_dir: the directory the weights will be loaded from. + source: the place that saved the bundle package. + If `source` is `github`, the bundle package should be within the releases. + repo: the repo name. If the weights need to be downloaded first and `url` is `None`, it must be provided. + url: the url to download the data. + progress: whether to display a progress bar when downloading. + extract: whether to extract the downloaded file. + map_location: pytorch API parameter for `torch.load` or `torch.jit.load`. + + """ + model_file_path = os.path.join(bundle_dir, weights_name) + if not os.path.exists(model_file_path): + if package is None and url is None: + raise ValueError(f"To download and load model from source: {source}, `package` must be provided.") + if repo is None and url is None: + raise ValueError(f"To download and load model from source: {source}, `repo` must be provided.") + download( + package=package, + bundle_dir=bundle_dir, + filename=weights_name, + source=source, + repo=repo, + url=url, + progress=progress, + extract=extract, + ) + # loading with `torch.jit.load` + if is_ts_model is True: + return torch.jit.load(model_file_path, map_location=map_location) + # loading with `torch.load` + model_dict = torch.load(model_file_path, map_location=map_location) + return model_dict def run( From b9fc7dcaef3cdcb3ce05b9e870d81478fac2016d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Apr 2022 11:28:29 +0000 Subject: [PATCH 04/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 980bae2f19..2d7d194d96 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -197,7 +197,7 @@ def download( python -m monai.bundle download --package --source "github" --repo # Execute this module as a CLI entry, and download a single file: - python -m monai.bundle download --package --filename --repo + python -m monai.bundle download --package --filename --repo # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. # Other args still can override the default args at runtime: From 981a990745bab65e4a3dd4d59b092ebf5b4edc76 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 14 Apr 2022 22:37:43 +0800 Subject: [PATCH 05/20] rename args and remove a few places Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 83 +++++++++++++++-------------------------- 1 file changed, 30 insertions(+), 53 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 980bae2f19..5222c5c445 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -151,41 +151,34 @@ def _download_from_github( download_path: Path, filename: Optional[str] = None, progress: bool = True, - extract: bool = False, ): if len(repo.split("/")) != 2: raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name`.") repo_owner, repo_name = repo.split("/") if filename is None: - # download the whole bundle package if filename is not provided + # download the whole bundle if filename is not provided assets_info = _get_git_release_assets(repo_owner, repo_name, tag_name=tag_name) for name, url in assets_info.items(): download_url(url=url, filepath=download_path / f"{name}", hash_val=None, progress=progress) - if extract is True: - logger.info("When download a whole bundle package, extract is not supported, skip extracting.") - logger.info(f"All files within the bundle package {tag_name} are downloaded.") + logger.info(f"All files within the bundle {tag_name} are downloaded.") else: # download a single file url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename) - filepath = download_path / f"{filename}" - download_url(url=url, filepath=filepath, hash_val=None, progress=progress) - if extract is True: - extractall(filepath=filepath, output_dir=download_path, has_base=True) + download_url(url=url, filepath=download_path / f"{filename}", hash_val=None, progress=progress) def download( - package: Optional[str] = None, + name: Optional[str] = None, bundle_dir: Optional[PathLike] = None, filename: Optional[str] = None, source: str = "github", repo: Optional[str] = None, url: Optional[str] = None, progress: bool = True, - extract: bool = False, args_file: Optional[str] = None, ): """ - download the bundle package or a file that belongs to the package from the specified source or url. + download the bundle or a file that belongs to the bundle from the specified source or url. This function refers to: https://pytorch.org/docs/stable/_modules/torch/hub.html @@ -193,81 +186,68 @@ def download( .. code-block:: bash - # Execute this module as a CLI entry, and download the whole bundle package: - python -m monai.bundle download --package --source "github" --repo + # Execute this module as a CLI entry, and download the whole bundle: + python -m monai.bundle download --name --source "github" --repo # Execute this module as a CLI entry, and download a single file: - python -m monai.bundle download --package --filename --repo + python -m monai.bundle download --name --filename --repo # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. # Other args still can override the default args at runtime: python -m monai.bundle download --args_file "/workspace/data/args.json" --filename Args: - package: the bundle package name. If `None` and `url` is `None`, it must be provided in `args_file`. + name: the bundle name. If `None` and `url` is `None`, it must be provided in `args_file`. If `source` is `github`, it should be the same as the release tag. - bundle_dir: target directory to store the download data. Default is `bundle` subfolder under`torch.hub get_dir()`. + bundle_dir: target directory to store the download data. If `None`, it must be provided in `args_file`. filename: the filename that needs to be downloaded. - If `source` is `github` and filename is `None`, the whole bundle package will be downloaded. - source: the place that saved the bundle package. - If `source` is `github`, the bundle package should be within the releases. + If `source` is `github` and filename is `None`, the whole bundle will be downloaded. + source: the place that saved the bundle. + If `source` is `github`, the bundle should be within the releases. repo: the repo name. If `None` and `url` is `None`, it must be provided in `args_file`. If `source` is `github`, it should be in the form of `repo_owner/repo_name`. For example: `Project-MONAI/MONAI`. url: the url to download the data. It is an optional argument and if not `None`, data will be downloaded directly and `source` will not be checked. progress: whether to display a progress bar. - extract: whether to extract the downloaded file. This argument only works when download a single file. args_file: a JSON or YAML file to provide default values for all the args in this function. so that the command line inputs can be simplified. """ _args = _update_args( args=args_file, - package=package, + name=name, bundle_dir=bundle_dir, filename=filename, source=source, repo=repo, url=url, progress=progress, - extract=extract, ) - if "package" not in _args and url is None: - raise ValueError(f"To download from source: {source}, `package` must be provided.") + if "name" not in _args and url is None: + raise ValueError(f"To download from source: {source}, `name` must be provided.") + if "bundle_dir" not in _args: + raise ValueError(f"`bundle_dir` is required for 'monai.bundle download'.\n{run.__doc__}.") if "repo" not in _args and url is None: raise ValueError(f"To download from source: {source}, `repo` must be provided.") _log_input_summary(tag="download", args=_args) - package_, bundle_dir_, filename_, source_, repo_, url_, progress_, extract_ = _pop_args( - _args, package=None, bundle_dir=None, filename=None, source="github", repo=None, url=None, progress=True, extract=False + bundle_dir_, name_, filename_, source_, repo_, url_, progress_ = _pop_args( + _args, "bundle_dir", name=None, filename=None, source="github", repo=None, url=None, progress=True ) - if not bundle_dir_: - get_dir, has_home = optional_import("torch.hub", name="get_dir") - if has_home: - bundle_dir_ = Path(get_dir()) / "bundle" - else: - raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") bundle_dir_ = Path(bundle_dir_) if url_ is not None: if filename_ is not None: - filepath = bundle_dir_ / f"{filename_}" - download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_) - if extract_ is True: - extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True) - else: - download_url(url=url_, filepath=bundle_dir_, hash_val=None, progress=progress_) - if extract_ is True: - logger.info("When download a whole bundle package, extract is not supported, skip extracting.") + bundle_dir_ = bundle_dir_ / f"{filename_}" + download_url(url=url_, filepath=bundle_dir_, hash_val=None, progress=progress_) elif source_ == "github": _download_from_github( repo=repo_, - tag_name=package_, + tag_name=name_, download_path=bundle_dir_, filename=filename_, progress=progress_, - extract=extract_, ) else: raise NotImplementedError("So far, only support to download from url or `github` source.") @@ -276,13 +256,12 @@ def download( def load( weights_name: str = "model.pt", is_ts_model: bool = False, - package: Optional[str] = None, + name: Optional[str] = None, bundle_dir: PathLike = ".", source: str = "github", repo: Optional[str] = None, url: Optional[str] = None, progress: bool = True, - extract: bool = False, map_location=None, ): """ @@ -291,33 +270,31 @@ def load( Args: weights_name: the name of the weights file that will be loaded. is_ts_model, a flag to specify if the weights file is a TorchScript module. - package: the bundle package name. + name: the bundle name. If the weights need to be downloaded first and `url` is `None`, it must be provided. bundle_dir: the directory the weights will be loaded from. - source: the place that saved the bundle package. - If `source` is `github`, the bundle package should be within the releases. + source: the place that saved the bundle. + If `source` is `github`, the bundle should be within the releases. repo: the repo name. If the weights need to be downloaded first and `url` is `None`, it must be provided. url: the url to download the data. progress: whether to display a progress bar when downloading. - extract: whether to extract the downloaded file. map_location: pytorch API parameter for `torch.load` or `torch.jit.load`. """ model_file_path = os.path.join(bundle_dir, weights_name) if not os.path.exists(model_file_path): - if package is None and url is None: - raise ValueError(f"To download and load model from source: {source}, `package` must be provided.") + if name is None and url is None: + raise ValueError(f"To download and load model from source: {source}, `name` must be provided.") if repo is None and url is None: raise ValueError(f"To download and load model from source: {source}, `repo` must be provided.") download( - package=package, + name=name, bundle_dir=bundle_dir, filename=weights_name, source=source, repo=repo, url=url, progress=progress, - extract=extract, ) # loading with `torch.jit.load` if is_ts_model is True: From 2769ad2cafec3a5fe413784bac0745c890d166ba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Apr 2022 14:40:25 +0000 Subject: [PATCH 06/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 5222c5c445..2a49b4fc25 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -190,7 +190,7 @@ def download( python -m monai.bundle download --name --source "github" --repo # Execute this module as a CLI entry, and download a single file: - python -m monai.bundle download --name --filename --repo + python -m monai.bundle download --name --filename --repo # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. # Other args still can override the default args at runtime: From 31e9280ffca77376f3f8fd63e6eaf31d5d4feb06 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 14 Apr 2022 22:42:20 +0800 Subject: [PATCH 07/20] fix flake8 issue Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 5222c5c445..77c38e476e 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -21,9 +21,7 @@ import torch from torch.cuda import is_available -import warnings - -from monai.apps.utils import download_url, extractall, get_logger +from monai.apps.utils import download_url, get_logger from monai.bundle.config_parser import ConfigParser from monai.config import IgniteInfo, PathLike from monai.data import save_net_with_metadata @@ -190,7 +188,7 @@ def download( python -m monai.bundle download --name --source "github" --repo # Execute this module as a CLI entry, and download a single file: - python -m monai.bundle download --name --filename --repo + python -m monai.bundle download --name --filename --repo # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. # Other args still can override the default args at runtime: From 95227b1df225543d06a65deae8fa5ed587e7451c Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 18 Apr 2022 21:03:10 +0800 Subject: [PATCH 08/20] enhance with reviews Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 130 +++++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 63 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 77c38e476e..0166cde460 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -21,8 +21,9 @@ import torch from torch.cuda import is_available -from monai.apps.utils import download_url, get_logger +from monai.apps.utils import _basename, download_url, extractall, get_logger from monai.bundle.config_parser import ConfigParser +from monai.bundle.utils import ID_SEP_KEY from monai.config import IgniteInfo, PathLike from monai.data import save_net_with_metadata from monai.networks import convert_to_torchscript, copy_model_state @@ -119,20 +120,11 @@ def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int return tuple(ret) -def _get_git_release_url( - repo_owner: str, - repo_name: str, - tag_name: str, - filename: str, -): +def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filename: str): return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}" -def _get_git_release_assets( - repo_owner: str, - repo_name: str, - tag_name: str, -): +def _get_git_release_assets(repo_owner: str, repo_name: str, tag_name: str): url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{tag_name}" resp = requests_get(url) resp.raise_for_status() @@ -144,11 +136,7 @@ def _get_git_release_assets( def _download_from_github( - repo: str, - tag_name: str, - download_path: Path, - filename: Optional[str] = None, - progress: bool = True, + repo: str, tag_name: str, download_path: Path, filename: Optional[str] = None, progress: bool = True ): if len(repo.split("/")) != 2: raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name`.") @@ -157,18 +145,23 @@ def _download_from_github( # download the whole bundle if filename is not provided assets_info = _get_git_release_assets(repo_owner, repo_name, tag_name=tag_name) for name, url in assets_info.items(): - download_url(url=url, filepath=download_path / f"{name}", hash_val=None, progress=progress) + filepath = download_path / f"{name}" + download_url(url=url, filepath=filepath, hash_val=None, progress=progress) + if filepath.name.endswith(("zip", "tar", "tar.gz")): + extractall(filepath=filepath, output_dir=download_path, has_base=True) logger.info(f"All files within the bundle {tag_name} are downloaded.") else: # download a single file url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename) - download_url(url=url, filepath=download_path / f"{filename}", hash_val=None, progress=progress) + filepath = download_path / f"{filename}" + download_url(url=url, filepath=filepath, hash_val=None, progress=progress) + if filepath.name.endswith(("zip", "tar", "tar.gz")): + extractall(filepath=filepath, output_dir=download_path, has_base=True) def download( name: Optional[str] = None, bundle_dir: Optional[PathLike] = None, - filename: Optional[str] = None, source: str = "github", repo: Optional[str] = None, url: Optional[str] = None, @@ -177,6 +170,7 @@ def download( ): """ download the bundle or a file that belongs to the bundle from the specified source or url. + "zip", "tar" and "tar.gz" files, will be extracted after downloading. This function refers to: https://pytorch.org/docs/stable/_modules/torch/hub.html @@ -185,77 +179,77 @@ def download( .. code-block:: bash # Execute this module as a CLI entry, and download the whole bundle: - python -m monai.bundle download --name --source "github" --repo + python -m monai.bundle download --name --source "github" --repo # Execute this module as a CLI entry, and download a single file: - python -m monai.bundle download --name --filename --repo + python -m monai.bundle download --name --repo # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. # Other args still can override the default args at runtime: - python -m monai.bundle download --args_file "/workspace/data/args.json" --filename + python -m monai.bundle download --args_file "/workspace/data/args.json" --repo Args: name: the bundle name. If `None` and `url` is `None`, it must be provided in `args_file`. If `source` is `github`, it should be the same as the release tag. + If only a single file is expected to be downloaded, `name` should be a string that consisted with + the bundle name, a separator `#` and the weights name. For example: if the bundle name is "spleen" and + the weights name is "model.pt", then `name` is "spleen#model.pt". bundle_dir: target directory to store the download data. If `None`, it must be provided in `args_file`. - filename: the filename that needs to be downloaded. - If `source` is `github` and filename is `None`, the whole bundle will be downloaded. source: the place that saved the bundle. If `source` is `github`, the bundle should be within the releases. repo: the repo name. If `None` and `url` is `None`, it must be provided in `args_file`. If `source` is `github`, it should be in the form of `repo_owner/repo_name`. For example: `Project-MONAI/MONAI`. - url: the url to download the data. It is an optional argument and if not `None`, - data will be downloaded directly and `source` will not be checked. + url: the url to download the data. If not `None`, data will be downloaded directly + and `source` will not be checked. + If `name` contains the filename, it will be used as the downloaded filename (without postfix). + Otherwise, the filename is determined by `monai.apps.utils._basename(url)`. progress: whether to display a progress bar. args_file: a JSON or YAML file to provide default values for all the args in this function. so that the command line inputs can be simplified. """ _args = _update_args( - args=args_file, - name=name, - bundle_dir=bundle_dir, - filename=filename, - source=source, - repo=repo, - url=url, - progress=progress, + args=args_file, name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress ) - if "name" not in _args and url is None: - raise ValueError(f"To download from source: {source}, `name` must be provided.") if "bundle_dir" not in _args: raise ValueError(f"`bundle_dir` is required for 'monai.bundle download'.\n{run.__doc__}.") + if "name" not in _args and url is None: + raise ValueError(f"To download from source: {source}, `name` must be provided.") if "repo" not in _args and url is None: raise ValueError(f"To download from source: {source}, `repo` must be provided.") _log_input_summary(tag="download", args=_args) - bundle_dir_, name_, filename_, source_, repo_, url_, progress_ = _pop_args( - _args, "bundle_dir", name=None, filename=None, source="github", repo=None, url=None, progress=True + bundle_dir_, name_, source_, repo_, url_, progress_ = _pop_args( + _args, "bundle_dir", name=None, source="github", repo=None, url=None, progress=True ) bundle_dir_ = Path(bundle_dir_) + filename: Optional[str] = None + if name_ is not None and len(name_.split(ID_SEP_KEY)) == 2: + name_, filename = name_.split(ID_SEP_KEY) + if url_ is not None: - if filename_ is not None: - bundle_dir_ = bundle_dir_ / f"{filename_}" - download_url(url=url_, filepath=bundle_dir_, hash_val=None, progress=progress_) + if filename is not None: + filepath = bundle_dir_ / f"{filename}" + else: + filepath = bundle_dir_ / f"{_basename(url_)}" + download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_) + filepath = Path(filepath) + if filepath.name.endswith(("zip", "tar", "tar.gz")): + extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True) elif source_ == "github": _download_from_github( - repo=repo_, - tag_name=name_, - download_path=bundle_dir_, - filename=filename_, - progress=progress_, + repo=repo_, tag_name=name_, download_path=bundle_dir_, filename=filename, progress=progress_ ) else: raise NotImplementedError("So far, only support to download from url or `github` source.") def load( - weights_name: str = "model.pt", - is_ts_model: bool = False, name: Optional[str] = None, - bundle_dir: PathLike = ".", + is_ts_model: bool = False, + bundle_dir: Optional[PathLike] = None, source: str = "github", repo: Optional[str] = None, url: Optional[str] = None, @@ -263,14 +257,17 @@ def load( map_location=None, ): """ - Download (if necessary) and load model weights. + Load model weights. If the weights file is not existing locally, it will be downloaded first. Args: - weights_name: the name of the weights file that will be loaded. + name: Bundle and weights name. If `None`, `url` should be provided, or the weights file is existing locally and + the default weights file should be named as "model.pt". + If not `None`, it should be a string that consisted with the bundle name, a separator `#` + and the weights name. For example: if the bundle name is "spleen" and the weights name is "model.pt", then + `name` is "spleen#model.pt". is_ts_model, a flag to specify if the weights file is a TorchScript module. - name: the bundle name. - If the weights need to be downloaded first and `url` is `None`, it must be provided. bundle_dir: the directory the weights will be loaded from. + Default is `bundle` subfolder under`torch.hub get_dir()`. source: the place that saved the bundle. If `source` is `github`, the bundle should be within the releases. repo: the repo name. If the weights need to be downloaded first and `url` is `None`, it must be provided. @@ -279,21 +276,28 @@ def load( map_location: pytorch API parameter for `torch.load` or `torch.jit.load`. """ + if bundle_dir is None: + get_dir, has_home = optional_import("torch.hub", name="get_dir") + if has_home: + bundle_dir = Path(get_dir()) / "bundle" + else: + raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") + bundle_dir = Path(bundle_dir) + + weights_name: str = "model.pt" + if name is not None: + if len(name.split(ID_SEP_KEY)) == 2: + weights_name = name.split(ID_SEP_KEY)[1] + else: + raise ValueError(f"The format of `name` is wrong.\n{run.__doc__}") model_file_path = os.path.join(bundle_dir, weights_name) + if not os.path.exists(model_file_path): if name is None and url is None: raise ValueError(f"To download and load model from source: {source}, `name` must be provided.") if repo is None and url is None: raise ValueError(f"To download and load model from source: {source}, `repo` must be provided.") - download( - name=name, - bundle_dir=bundle_dir, - filename=weights_name, - source=source, - repo=repo, - url=url, - progress=progress, - ) + download(name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress) # loading with `torch.jit.load` if is_ts_model is True: return torch.jit.load(model_file_path, map_location=map_location) From 0e530352525f73ff0b7abb1d30980cd901ed7a46 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 18 Apr 2022 22:13:56 +0800 Subject: [PATCH 09/20] add instantiate for load Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 0166cde460..dbea1a5169 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -22,6 +22,7 @@ from torch.cuda import is_available from monai.apps.utils import _basename, download_url, extractall, get_logger +from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser from monai.bundle.utils import ID_SEP_KEY from monai.config import IgniteInfo, PathLike @@ -255,9 +256,12 @@ def load( url: Optional[str] = None, progress: bool = True, map_location=None, + net_name: Optional[str] = None, + **net_kwargs, ): """ - Load model weights. If the weights file is not existing locally, it will be downloaded first. + Load model weights. If the weights file is not existing locally, it will be downloaded first. The function can + return the weights, or an instantiated network that loaded the weights. Args: name: Bundle and weights name. If `None`, `url` should be provided, or the weights file is existing locally and @@ -274,6 +278,8 @@ def load( url: the url to download the data. progress: whether to display a progress bar when downloading. map_location: pytorch API parameter for `torch.load` or `torch.jit.load`. + net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. + net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`. """ if bundle_dir is None: @@ -303,7 +309,15 @@ def load( return torch.jit.load(model_file_path, map_location=map_location) # loading with `torch.load` model_dict = torch.load(model_file_path, map_location=map_location) - return model_dict + + if net_name is not None: + net_config = {"_target_": net_name} + for k, v in net_kwargs.items(): + net_config[k] = v + configer = ConfigComponent(config=net_config) + model = configer.instantiate() + model.load_state_dict(model_dict) # type: ignore + return model def run( From 5cad06483985deabbd694bd9da194c38da492259 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 18 Apr 2022 22:21:10 +0800 Subject: [PATCH 10/20] fix black error Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index dbea1a5169..a1b4171ff5 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -316,7 +316,7 @@ def load( net_config[k] = v configer = ConfigComponent(config=net_config) model = configer.instantiate() - model.load_state_dict(model_dict) # type: ignore + model.load_state_dict(model_dict) # type: ignore return model From 3ece56728e0d002b7a0864eddd55be4c3c246536 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 19 Apr 2022 17:30:18 +0800 Subject: [PATCH 11/20] add unittest Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 29 ++++--- tests/test_bundle_download.py | 140 ++++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 9 deletions(-) create mode 100644 tests/test_bundle_download.py diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index a1b4171ff5..bc845b7bb2 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -180,14 +180,17 @@ def download( .. code-block:: bash # Execute this module as a CLI entry, and download the whole bundle: - python -m monai.bundle download --name --source "github" --repo + python -m monai.bundle download --name "bundle_name" --source "github" --repo "repo_owner/repo_name" # Execute this module as a CLI entry, and download a single file: - python -m monai.bundle download --name --repo + python -m monai.bundle download --name "bundle_name#filename" --repo "repo_owner/repo_name" + + # Execute this module as a CLI entry, and download a single file via URL: + python -m monai.bundle download --url # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. # Other args still can override the default args at runtime: - python -m monai.bundle download --args_file "/workspace/data/args.json" --repo + python -m monai.bundle download --args_file "/workspace/data/args.json" --repo "repo_owner/repo_name" Args: name: the bundle name. If `None` and `url` is `None`, it must be provided in `args_file`. @@ -195,7 +198,8 @@ def download( If only a single file is expected to be downloaded, `name` should be a string that consisted with the bundle name, a separator `#` and the weights name. For example: if the bundle name is "spleen" and the weights name is "model.pt", then `name` is "spleen#model.pt". - bundle_dir: target directory to store the download data. If `None`, it must be provided in `args_file`. + bundle_dir: target directory to store the download data. + Default is `bundle` subfolder under`torch.hub get_dir()`. source: the place that saved the bundle. If `source` is `github`, the bundle should be within the releases. repo: the repo name. If `None` and `url` is `None`, it must be provided in `args_file`. @@ -213,17 +217,22 @@ def download( _args = _update_args( args=args_file, name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress ) - if "bundle_dir" not in _args: - raise ValueError(f"`bundle_dir` is required for 'monai.bundle download'.\n{run.__doc__}.") + if "name" not in _args and url is None: raise ValueError(f"To download from source: {source}, `name` must be provided.") if "repo" not in _args and url is None: raise ValueError(f"To download from source: {source}, `repo` must be provided.") _log_input_summary(tag="download", args=_args) - bundle_dir_, name_, source_, repo_, url_, progress_ = _pop_args( - _args, "bundle_dir", name=None, source="github", repo=None, url=None, progress=True + name_, bundle_dir_, source_, repo_, url_, progress_ = _pop_args( + _args, name=None, bundle_dir=None, source="github", repo=None, url=None, progress=True ) + if bundle_dir_ is None: + get_dir, has_home = optional_import("torch.hub", name="get_dir") + if has_home: + bundle_dir_ = Path(get_dir()) / "bundle" + else: + raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") bundle_dir_ = Path(bundle_dir_) filename: Optional[str] = None @@ -279,6 +288,7 @@ def load( progress: whether to display a progress bar when downloading. map_location: pytorch API parameter for `torch.load` or `torch.jit.load`. net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. + This argument only works when loading weights. net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`. """ @@ -317,7 +327,8 @@ def load( configer = ConfigComponent(config=net_config) model = configer.instantiate() model.load_state_dict(model_dict) # type: ignore - return model + return model + return model_dict def run( diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py new file mode 100644 index 0000000000..3eec421dd2 --- /dev/null +++ b/tests/test_bundle_download.py @@ -0,0 +1,140 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import subprocess +import tempfile +import unittest + +import torch +from parameterized import parameterized + +import monai.networks.nets as nets +from monai.apps import check_hash +from monai.bundle import ConfigParser, download, load +from tests.utils import skip_if_windows + +TEST_CASE_1 = [ + ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], + "test_bundle", + "Project-MONAI/MONAI-extra-test-data", + "a131d39a0af717af32d19e565b434928", +] + +TEST_CASE_2 = [ + "test_bundle#network.json", + "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/test_bundle/network.json", + "a131d39a0af717af32d19e565b434928", +] + +TEST_CASE_3 = [ + ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], + "test_bundle", + "Project-MONAI/MONAI-extra-test-data", +] + +TEST_CASE_4 = [ + "test_bundle#model.ts", + "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/test_bundle/model.ts", + ["test_output.pt", "test_input.pt"], + "test_bundle", + "Project-MONAI/MONAI-extra-test-data", +] + + +@skip_if_windows +class TestDownload(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_download_bundle(self, bundle_files, bundle_name, repo, hash_val): + # download a whole bundle from github releases + with tempfile.TemporaryDirectory() as tempdir: + bundle_dir = os.path.join(tempdir, "test_bundle") + cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--name", bundle_name, "--source", "github"] + cmd += ["--bundle_dir", bundle_dir, "--repo", repo, "--progress", "False"] + subprocess.check_call(cmd) + for file in bundle_files: + file_path = os.path.join(bundle_dir, file) + self.assertTrue(os.path.exists(file_path)) + # check the md5 hash of the json file + if file == bundle_files[2]: + self.assertTrue(check_hash(filepath=file_path, val=hash_val)) + + @parameterized.expand([TEST_CASE_2]) + def test_url_download_bundle(self, bundle_name, url, hash_val): + # download a single file from url, also use `args_file` + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"name": bundle_name, "bundle_dir": tempdir, "url": ""} + def_args_file = os.path.join(tempdir, "def_args.json") + parser = ConfigParser() + parser.export_config_file(config=def_args, filepath=def_args_file) + cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file] + cmd += ["--url", url] + subprocess.check_call(cmd) + file_path = os.path.join(tempdir, "network.json") + self.assertTrue(os.path.exists(file_path)) + self.assertTrue(check_hash(filepath=file_path, val=hash_val)) + + +class TestLoad(unittest.TestCase): + @parameterized.expand([TEST_CASE_3]) + def test_load_weights(self, bundle_files, bundle_name, repo): + # download bundle, and load weights from the downloaded path + with tempfile.TemporaryDirectory() as tempdir: + # download bundle + download(name=bundle_name, repo=repo, bundle_dir=tempdir, progress=False) + + # load weights only + weights_name = bundle_name + "#" + bundle_files[0] + weights = load(name=weights_name, bundle_dir=tempdir, progress=False) + + # prepare network + with open(os.path.join(tempdir, bundle_files[2])) as f: + net_args = json.load(f)["network_def"] + model_name = net_args["_target_"] + del net_args["_target_"] + model = nets.__dict__[model_name](**net_args) + model.load_state_dict(weights) + model.eval() + + # prepare data and test + input_tensor = torch.load(os.path.join(tempdir, bundle_files[4])) + output = model.forward(input_tensor) + expected_output = torch.load(os.path.join(tempdir, bundle_files[3])) + torch.testing.assert_allclose(output, expected_output) + + # load instantiated model directly and test + model_2 = load(name=weights_name, bundle_dir=tempdir, progress=False, net_name=model_name, **net_args) + model_2.eval() + output_2 = model_2.forward(input_tensor) + torch.testing.assert_allclose(output_2, expected_output) + + @parameterized.expand([TEST_CASE_4]) + def test_load_ts_module(self, ts_name, url, bundle_files, bundle_name, repo): + # load ts module from url, and download input and output tensors for testing + with tempfile.TemporaryDirectory() as tempdir: + # load ts module + model_ts = load(name=ts_name, is_ts_model=True, bundle_dir=tempdir, url=url, progress=False) + + # download input and output tensors + for file in bundle_files: + download_name = bundle_name + "#" + file + download(name=download_name, repo=repo, bundle_dir=tempdir, progress=False) + + # prepare and test + input_tensor = torch.load(os.path.join(tempdir, bundle_files[1])) + output = model_ts.forward(input_tensor) + expected_output = torch.load(os.path.join(tempdir, bundle_files[0])) + torch.testing.assert_allclose(output, expected_output) + + +if __name__ == "__main__": + unittest.main() From 15dc4b6b9e125737165d640f5a8904cf1e5af258 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 19 Apr 2022 17:34:43 +0800 Subject: [PATCH 12/20] add load to docs Signed-off-by: Yiheng Wang --- docs/source/bundle.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 023effaad1..a28db04091 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -37,6 +37,7 @@ Model Bundle --------- .. autofunction:: ckpt_export .. autofunction:: download +.. autofunction:: load .. autofunction:: run .. autofunction:: verify_metadata .. autofunction:: verify_net_in_out From 12d400a98dc95e6dc389521abfac5a7479d77153 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 19 Apr 2022 18:03:58 +0800 Subject: [PATCH 13/20] add skip Signed-off-by: Yiheng Wang --- tests/test_bundle_download.py | 146 ++++++++++++++++++---------------- 1 file changed, 77 insertions(+), 69 deletions(-) diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 3eec421dd2..4bf389070f 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -21,7 +21,7 @@ import monai.networks.nets as nets from monai.apps import check_hash from monai.bundle import ConfigParser, download, load -from tests.utils import skip_if_windows +from tests.utils import skip_if_downloading_fails, skip_if_quick, skip_if_windows TEST_CASE_1 = [ ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], @@ -54,86 +54,94 @@ @skip_if_windows class TestDownload(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) + @skip_if_quick def test_download_bundle(self, bundle_files, bundle_name, repo, hash_val): - # download a whole bundle from github releases - with tempfile.TemporaryDirectory() as tempdir: - bundle_dir = os.path.join(tempdir, "test_bundle") - cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--name", bundle_name, "--source", "github"] - cmd += ["--bundle_dir", bundle_dir, "--repo", repo, "--progress", "False"] - subprocess.check_call(cmd) - for file in bundle_files: - file_path = os.path.join(bundle_dir, file) - self.assertTrue(os.path.exists(file_path)) - # check the md5 hash of the json file - if file == bundle_files[2]: - self.assertTrue(check_hash(filepath=file_path, val=hash_val)) + with skip_if_downloading_fails(): + # download a whole bundle from github releases + with tempfile.TemporaryDirectory() as tempdir: + bundle_dir = os.path.join(tempdir, "test_bundle") + cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--name", bundle_name, "--source", "github"] + cmd += ["--bundle_dir", bundle_dir, "--repo", repo, "--progress", "False"] + subprocess.check_call(cmd) + for file in bundle_files: + file_path = os.path.join(bundle_dir, file) + self.assertTrue(os.path.exists(file_path)) + # check the md5 hash of the json file + if file == bundle_files[2]: + self.assertTrue(check_hash(filepath=file_path, val=hash_val)) @parameterized.expand([TEST_CASE_2]) + @skip_if_quick def test_url_download_bundle(self, bundle_name, url, hash_val): - # download a single file from url, also use `args_file` - with tempfile.TemporaryDirectory() as tempdir: - def_args = {"name": bundle_name, "bundle_dir": tempdir, "url": ""} - def_args_file = os.path.join(tempdir, "def_args.json") - parser = ConfigParser() - parser.export_config_file(config=def_args, filepath=def_args_file) - cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file] - cmd += ["--url", url] - subprocess.check_call(cmd) - file_path = os.path.join(tempdir, "network.json") - self.assertTrue(os.path.exists(file_path)) - self.assertTrue(check_hash(filepath=file_path, val=hash_val)) + with skip_if_downloading_fails(): + # download a single file from url, also use `args_file` + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"name": bundle_name, "bundle_dir": tempdir, "url": ""} + def_args_file = os.path.join(tempdir, "def_args.json") + parser = ConfigParser() + parser.export_config_file(config=def_args, filepath=def_args_file) + cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file] + cmd += ["--url", url] + subprocess.check_call(cmd) + file_path = os.path.join(tempdir, "network.json") + self.assertTrue(os.path.exists(file_path)) + self.assertTrue(check_hash(filepath=file_path, val=hash_val)) class TestLoad(unittest.TestCase): @parameterized.expand([TEST_CASE_3]) + @skip_if_quick def test_load_weights(self, bundle_files, bundle_name, repo): - # download bundle, and load weights from the downloaded path - with tempfile.TemporaryDirectory() as tempdir: - # download bundle - download(name=bundle_name, repo=repo, bundle_dir=tempdir, progress=False) - - # load weights only - weights_name = bundle_name + "#" + bundle_files[0] - weights = load(name=weights_name, bundle_dir=tempdir, progress=False) - - # prepare network - with open(os.path.join(tempdir, bundle_files[2])) as f: - net_args = json.load(f)["network_def"] - model_name = net_args["_target_"] - del net_args["_target_"] - model = nets.__dict__[model_name](**net_args) - model.load_state_dict(weights) - model.eval() - - # prepare data and test - input_tensor = torch.load(os.path.join(tempdir, bundle_files[4])) - output = model.forward(input_tensor) - expected_output = torch.load(os.path.join(tempdir, bundle_files[3])) - torch.testing.assert_allclose(output, expected_output) - - # load instantiated model directly and test - model_2 = load(name=weights_name, bundle_dir=tempdir, progress=False, net_name=model_name, **net_args) - model_2.eval() - output_2 = model_2.forward(input_tensor) - torch.testing.assert_allclose(output_2, expected_output) + with skip_if_downloading_fails(): + # download bundle, and load weights from the downloaded path + with tempfile.TemporaryDirectory() as tempdir: + # download bundle + download(name=bundle_name, repo=repo, bundle_dir=tempdir, progress=False) + + # load weights only + weights_name = bundle_name + "#" + bundle_files[0] + weights = load(name=weights_name, bundle_dir=tempdir, progress=False) + + # prepare network + with open(os.path.join(tempdir, bundle_files[2])) as f: + net_args = json.load(f)["network_def"] + model_name = net_args["_target_"] + del net_args["_target_"] + model = nets.__dict__[model_name](**net_args) + model.load_state_dict(weights) + model.eval() + + # prepare data and test + input_tensor = torch.load(os.path.join(tempdir, bundle_files[4])) + output = model.forward(input_tensor) + expected_output = torch.load(os.path.join(tempdir, bundle_files[3])) + torch.testing.assert_allclose(output, expected_output) + + # load instantiated model directly and test + model_2 = load(name=weights_name, bundle_dir=tempdir, progress=False, net_name=model_name, **net_args) + model_2.eval() + output_2 = model_2.forward(input_tensor) + torch.testing.assert_allclose(output_2, expected_output) @parameterized.expand([TEST_CASE_4]) + @skip_if_quick def test_load_ts_module(self, ts_name, url, bundle_files, bundle_name, repo): - # load ts module from url, and download input and output tensors for testing - with tempfile.TemporaryDirectory() as tempdir: - # load ts module - model_ts = load(name=ts_name, is_ts_model=True, bundle_dir=tempdir, url=url, progress=False) - - # download input and output tensors - for file in bundle_files: - download_name = bundle_name + "#" + file - download(name=download_name, repo=repo, bundle_dir=tempdir, progress=False) - - # prepare and test - input_tensor = torch.load(os.path.join(tempdir, bundle_files[1])) - output = model_ts.forward(input_tensor) - expected_output = torch.load(os.path.join(tempdir, bundle_files[0])) - torch.testing.assert_allclose(output, expected_output) + with skip_if_downloading_fails(): + # load ts module from url, and download input and output tensors for testing + with tempfile.TemporaryDirectory() as tempdir: + # load ts module + model_ts = load(name=ts_name, is_ts_model=True, bundle_dir=tempdir, url=url, progress=False) + + # download input and output tensors + for file in bundle_files: + download_name = bundle_name + "#" + file + download(name=download_name, repo=repo, bundle_dir=tempdir, progress=False) + + # prepare and test + input_tensor = torch.load(os.path.join(tempdir, bundle_files[1])) + output = model_ts.forward(input_tensor) + expected_output = torch.load(os.path.join(tempdir, bundle_files[0])) + torch.testing.assert_allclose(output, expected_output) if __name__ == "__main__": From 78bbfa6234ecd79f3e075e8f8aab214622302d95 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 19 Apr 2022 19:03:08 +0800 Subject: [PATCH 14/20] add schemaerror Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index bc845b7bb2..31dfb089ff 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -31,7 +31,7 @@ from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import validate, _ = optional_import("jsonschema", name="validate") -ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") +exceptions, _ = optional_import("jsonschema", name="exceptions") Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") requests_get, has_requests = optional_import("requests", name="get") @@ -464,10 +464,13 @@ def verify_metadata( try: # the rest key-values in the _args are for `validate` API validate(instance=metadata, schema=schema, **_args) - except ValidationError as e: + except exceptions.ValidationError as e: # as the error message is very long, only extract the key information logger.info(re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`.") return + except exceptions.SchemaError as e: + logger.info(str(e)) + return logger.info("metadata is verified with no error.") From 780ecb70d641077e41fc370a6010127920cddc76 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 20 Apr 2022 19:23:27 +0800 Subject: [PATCH 15/20] fix partial places Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 88 ++++++++++++++++------------------- tests/test_bundle_download.py | 35 ++++++++------ 2 files changed, 61 insertions(+), 62 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 31dfb089ff..b026a97530 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -148,16 +148,12 @@ def _download_from_github( for name, url in assets_info.items(): filepath = download_path / f"{name}" download_url(url=url, filepath=filepath, hash_val=None, progress=progress) - if filepath.name.endswith(("zip", "tar", "tar.gz")): - extractall(filepath=filepath, output_dir=download_path, has_base=True) logger.info(f"All files within the bundle {tag_name} are downloaded.") else: # download a single file url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename) filepath = download_path / f"{filename}" download_url(url=url, filepath=filepath, hash_val=None, progress=progress) - if filepath.name.endswith(("zip", "tar", "tar.gz")): - extractall(filepath=filepath, output_dir=download_path, has_base=True) def download( @@ -196,9 +192,9 @@ def download( name: the bundle name. If `None` and `url` is `None`, it must be provided in `args_file`. If `source` is `github`, it should be the same as the release tag. If only a single file is expected to be downloaded, `name` should be a string that consisted with - the bundle name, a separator `#` and the weights name. For example: if the bundle name is "spleen" and - the weights name is "model.pt", then `name` is "spleen#model.pt". - bundle_dir: target directory to store the download data. + the bundle name, a separator `#` and the filename. For example: if the bundle name is "spleen" and + the filename is "model.pt", then `name` is "spleen#model.pt". + bundle_dir: target directory to store the downloaded data. Default is `bundle` subfolder under`torch.hub get_dir()`. source: the place that saved the bundle. If `source` is `github`, the bundle should be within the releases. @@ -218,10 +214,8 @@ def download( args=args_file, name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress ) - if "name" not in _args and url is None: - raise ValueError(f"To download from source: {source}, `name` must be provided.") - if "repo" not in _args and url is None: - raise ValueError(f"To download from source: {source}, `repo` must be provided.") + if ("name" not in _args or "repo" not in _args) and url is None: + raise ValueError(f"To download from source: {source}, `name`and `repo`must be provided, got {name} and {repo}.") _log_input_summary(tag="download", args=_args) name_, bundle_dir_, source_, repo_, url_, progress_ = _pop_args( _args, name=None, bundle_dir=None, source="github", repo=None, url=None, progress=True @@ -253,40 +247,40 @@ def download( repo=repo_, tag_name=name_, download_path=bundle_dir_, filename=filename, progress=progress_ ) else: - raise NotImplementedError("So far, only support to download from url or `github` source.") + raise NotImplementedError( + f"Currently only download from provided URL in `url` or Github is implemented, got source: {source}." + ) def load( - name: Optional[str] = None, + name: str, is_ts_model: bool = False, bundle_dir: Optional[PathLike] = None, source: str = "github", repo: Optional[str] = None, - url: Optional[str] = None, progress: bool = True, - map_location=None, + device: str = "cpu", net_name: Optional[str] = None, **net_kwargs, ): """ - Load model weights. If the weights file is not existing locally, it will be downloaded first. The function can - return the weights, or an instantiated network that loaded the weights. + Load model weights or TorchScript module. If the weights file does not exist locally, it will be downloaded first. + The function can return weights, an instantiated network that loaded the weights, or a TorchScript module. Args: - name: Bundle and weights name. If `None`, `url` should be provided, or the weights file is existing locally and - the default weights file should be named as "model.pt". - If not `None`, it should be a string that consisted with the bundle name, a separator `#` - and the weights name. For example: if the bundle name is "spleen" and the weights name is "model.pt", then - `name` is "spleen#model.pt". - is_ts_model, a flag to specify if the weights file is a TorchScript module. + name: can be a string that contains only the bundle name, or a string that consists with the + bundle name, a separator `#` and the weights name. + For example, `name="spleen#model.pt"` means the bundle name is `spleen` and the weights name is `model.pt`. + If the weights name is not contained, `model.pt` or `model.ts` (according to the argument `is_ts_model`) + will be used as the default weights name. + is_ts_model: a flag to specify if the weights file is a TorchScript module. bundle_dir: the directory the weights will be loaded from. Default is `bundle` subfolder under`torch.hub get_dir()`. source: the place that saved the bundle. If `source` is `github`, the bundle should be within the releases. repo: the repo name. If the weights need to be downloaded first and `url` is `None`, it must be provided. - url: the url to download the data. progress: whether to display a progress bar when downloading. - map_location: pytorch API parameter for `torch.load` or `torch.jit.load`. + device: target device of returned weights or module. net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. This argument only works when loading weights. net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`. @@ -300,35 +294,31 @@ def load( raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") bundle_dir = Path(bundle_dir) - weights_name: str = "model.pt" - if name is not None: - if len(name.split(ID_SEP_KEY)) == 2: - weights_name = name.split(ID_SEP_KEY)[1] - else: - raise ValueError(f"The format of `name` is wrong.\n{run.__doc__}") + weights_name: str = "model.ts" if is_ts_model is True else "model.pt" + if len(name.split(ID_SEP_KEY)) == 2: + weights_name = name.split(ID_SEP_KEY)[1] + elif len(name.split(ID_SEP_KEY)) == 1: + name = name + "#" + weights_name + else: + raise ValueError(f"The format of `name` is wrong.\n{run.__doc__}") model_file_path = os.path.join(bundle_dir, weights_name) - if not os.path.exists(model_file_path): - if name is None and url is None: - raise ValueError(f"To download and load model from source: {source}, `name` must be provided.") - if repo is None and url is None: - raise ValueError(f"To download and load model from source: {source}, `repo` must be provided.") - download(name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress) + download(name=name, bundle_dir=bundle_dir, source=source, repo=repo, progress=progress) + # loading with `torch.jit.load` if is_ts_model is True: - return torch.jit.load(model_file_path, map_location=map_location) + return torch.jit.load(model_file_path, map_location=device) # loading with `torch.load` - model_dict = torch.load(model_file_path, map_location=map_location) - - if net_name is not None: - net_config = {"_target_": net_name} - for k, v in net_kwargs.items(): - net_config[k] = v - configer = ConfigComponent(config=net_config) - model = configer.instantiate() - model.load_state_dict(model_dict) # type: ignore - return model - return model_dict + model_dict = torch.load(model_file_path, map_location=device) + + if net_name is None: + return model_dict + net_kwargs["_target_"] = net_name + configer = ConfigComponent(config=net_kwargs) + model = configer.instantiate() + model.to(device) # type: ignore + model.load_state_dict(model_dict) # type: ignore + return model def run( diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 4bf389070f..fc1dd29045 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -40,14 +40,14 @@ ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], "test_bundle", "Project-MONAI/MONAI-extra-test-data", + "cuda" if torch.cuda.is_available() else "cpu", ] TEST_CASE_4 = [ - "test_bundle#model.ts", - "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/test_bundle/model.ts", ["test_output.pt", "test_input.pt"], "test_bundle", "Project-MONAI/MONAI-extra-test-data", + "cuda" if torch.cuda.is_available() else "cpu", ] @@ -91,7 +91,7 @@ def test_url_download_bundle(self, bundle_name, url, hash_val): class TestLoad(unittest.TestCase): @parameterized.expand([TEST_CASE_3]) @skip_if_quick - def test_load_weights(self, bundle_files, bundle_name, repo): + def test_load_weights(self, bundle_files, bundle_name, repo, device): with skip_if_downloading_fails(): # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: @@ -100,7 +100,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo): # load weights only weights_name = bundle_name + "#" + bundle_files[0] - weights = load(name=weights_name, bundle_dir=tempdir, progress=False) + weights = load(name=weights_name, bundle_dir=tempdir, progress=False, device=device) # prepare network with open(os.path.join(tempdir, bundle_files[2])) as f: @@ -108,39 +108,48 @@ def test_load_weights(self, bundle_files, bundle_name, repo): model_name = net_args["_target_"] del net_args["_target_"] model = nets.__dict__[model_name](**net_args) + model.to(device) model.load_state_dict(weights) model.eval() # prepare data and test - input_tensor = torch.load(os.path.join(tempdir, bundle_files[4])) + input_tensor = torch.load(os.path.join(tempdir, bundle_files[4]), map_location=device) output = model.forward(input_tensor) - expected_output = torch.load(os.path.join(tempdir, bundle_files[3])) + expected_output = torch.load(os.path.join(tempdir, bundle_files[3]), map_location=device) torch.testing.assert_allclose(output, expected_output) # load instantiated model directly and test - model_2 = load(name=weights_name, bundle_dir=tempdir, progress=False, net_name=model_name, **net_args) + model_2 = load( + name=weights_name, + bundle_dir=tempdir, + progress=False, + device=device, + net_name=model_name, + **net_args, + ) model_2.eval() output_2 = model_2.forward(input_tensor) torch.testing.assert_allclose(output_2, expected_output) @parameterized.expand([TEST_CASE_4]) @skip_if_quick - def test_load_ts_module(self, ts_name, url, bundle_files, bundle_name, repo): + def test_load_ts_module(self, bundle_files, bundle_name, repo, device): with skip_if_downloading_fails(): - # load ts module from url, and download input and output tensors for testing + # load ts module, the module name is not included with tempfile.TemporaryDirectory() as tempdir: # load ts module - model_ts = load(name=ts_name, is_ts_model=True, bundle_dir=tempdir, url=url, progress=False) - + model_ts = load( + name=bundle_name, is_ts_model=True, bundle_dir=tempdir, repo=repo, progress=False, device=device + ) # download input and output tensors for file in bundle_files: download_name = bundle_name + "#" + file download(name=download_name, repo=repo, bundle_dir=tempdir, progress=False) # prepare and test - input_tensor = torch.load(os.path.join(tempdir, bundle_files[1])) + input_tensor = torch.load(os.path.join(tempdir, bundle_files[1]), map_location=device) output = model_ts.forward(input_tensor) - expected_output = torch.load(os.path.join(tempdir, bundle_files[0])) + expected_output = torch.load(os.path.join(tempdir, bundle_files[0]), map_location=device) torch.testing.assert_allclose(output, expected_output) From 627ae480a11a13b195d82f0bc06dafc0bc045b07 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 21 Apr 2022 14:48:53 +0800 Subject: [PATCH 16/20] download zip bundle Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 131 ++++++++++++---------------------- tests/test_bundle_download.py | 67 ++++++++--------- 2 files changed, 73 insertions(+), 125 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 23abecd83e..3ee86f83ad 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -24,7 +24,6 @@ from monai.apps.utils import _basename, download_url, extractall, get_logger from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser -from monai.bundle.utils import ID_SEP_KEY from monai.config import IgniteInfo, PathLike from monai.data import save_net_with_metadata from monai.networks import convert_to_torchscript, copy_model_state @@ -125,35 +124,16 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}" -def _get_git_release_assets(repo_owner: str, repo_name: str, tag_name: str): - url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{tag_name}" - resp = requests_get(url) - resp.raise_for_status() - assets_list = json.loads(resp.text)["assets"] - assets_info = {} - for asset in assets_list: - assets_info[asset["name"]] = asset["browser_download_url"] - return assets_info - - -def _download_from_github( - repo: str, tag_name: str, download_path: Path, filename: Optional[str] = None, progress: bool = True -): - if len(repo.split("/")) != 2: - raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name`.") - repo_owner, repo_name = repo.split("/") - if filename is None: - # download the whole bundle if filename is not provided - assets_info = _get_git_release_assets(repo_owner, repo_name, tag_name=tag_name) - for name, url in assets_info.items(): - filepath = download_path / f"{name}" - download_url(url=url, filepath=filepath, hash_val=None, progress=progress) - logger.info(f"All files within the bundle {tag_name} are downloaded.") - else: - # download a single file - url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename) - filepath = download_path / f"{filename}" - download_url(url=url, filepath=filepath, hash_val=None, progress=progress) +def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True): + if len(repo.split("/")) != 3: + raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name/release_tag`.") + repo_owner, repo_name, tag_name = repo.split("/") + if ".zip" not in filename: + filename += ".zip" + url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename) + filepath = download_path / f"{filename}" + download_url(url=url, filepath=filepath, hash_val=None, progress=progress) + extractall(filepath=filepath, output_dir=download_path, has_base=True) def download( @@ -166,8 +146,8 @@ def download( args_file: Optional[str] = None, ): """ - download the bundle or a file that belongs to the bundle from the specified source or url. - "zip", "tar" and "tar.gz" files, will be extracted after downloading. + download bundle from the specified source or url. The bundle should be a zip file and it + will be extracted after downloading. This function refers to: https://pytorch.org/docs/stable/_modules/torch/hub.html @@ -175,36 +155,28 @@ def download( .. code-block:: bash - # Execute this module as a CLI entry, and download the whole bundle: - python -m monai.bundle download --name "bundle_name" --source "github" --repo "repo_owner/repo_name" - - # Execute this module as a CLI entry, and download a single file: - python -m monai.bundle download --name "bundle_name#filename" --repo "repo_owner/repo_name" + # Execute this module as a CLI entry, and download bundle: + python -m monai.bundle download --name "bundle_name" --source "github" --repo "repo_owner/repo_name/release_tag" - # Execute this module as a CLI entry, and download a single file via URL: - python -m monai.bundle download --url + # Execute this module as a CLI entry, and download bundle via URL: + python -m monai.bundle download --name "bundle_name" --url # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. # Other args still can override the default args at runtime: - python -m monai.bundle download --args_file "/workspace/data/args.json" --repo "repo_owner/repo_name" + python -m monai.bundle download --args_file "/workspace/data/args.json" --source "github" Args: - name: the bundle name. If `None` and `url` is `None`, it must be provided in `args_file`. - If `source` is `github`, it should be the same as the release tag. - If only a single file is expected to be downloaded, `name` should be a string that consisted with - the bundle name, a separator `#` and the filename. For example: if the bundle name is "spleen" and - the filename is "model.pt", then `name` is "spleen#model.pt". + name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`. bundle_dir: target directory to store the downloaded data. Default is `bundle` subfolder under`torch.hub get_dir()`. - source: the place that saved the bundle. + source: place that saved the bundle. If `source` is `github`, the bundle should be within the releases. - repo: the repo name. If `None` and `url` is `None`, it must be provided in `args_file`. - If `source` is `github`, it should be in the form of `repo_owner/repo_name`. - For example: `Project-MONAI/MONAI`. - url: the url to download the data. If not `None`, data will be downloaded directly + repo: repo name. If `None` and `url` is `None`, it must be provided in `args_file`. + If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`. + For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`. + url: url to download the data. If not `None`, data will be downloaded directly and `source` will not be checked. - If `name` contains the filename, it will be used as the downloaded filename (without postfix). - Otherwise, the filename is determined by `monai.apps.utils._basename(url)`. + If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`. progress: whether to display a progress bar. args_file: a JSON or YAML file to provide default values for all the args in this function. so that the command line inputs can be simplified. @@ -214,8 +186,6 @@ def download( args=args_file, name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress ) - if ("name" not in _args or "repo" not in _args) and url is None: - raise ValueError(f"To download from source: {source}, `name`and `repo`must be provided, got {name} and {repo}.") _log_input_summary(tag="download", args=_args) name_, bundle_dir_, source_, repo_, url_, progress_ = _pop_args( _args, name=None, bundle_dir=None, source="github", repo=None, url=None, progress=True @@ -229,32 +199,28 @@ def download( raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") bundle_dir_ = Path(bundle_dir_) - filename: Optional[str] = None - if name_ is not None and len(name_.split(ID_SEP_KEY)) == 2: - name_, filename = name_.split(ID_SEP_KEY) - if url_ is not None: - if filename is not None: - filepath = bundle_dir_ / f"{filename}" + if name is not None: + filepath = bundle_dir_ / f"{name}.zip" else: filepath = bundle_dir_ / f"{_basename(url_)}" download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_) - filepath = Path(filepath) - if filepath.name.endswith(("zip", "tar", "tar.gz")): - extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True) + extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True) elif source_ == "github": - _download_from_github( - repo=repo_, tag_name=name_, download_path=bundle_dir_, filename=filename, progress=progress_ - ) + if name_ is None or repo_ is None: + raise ValueError( + f"To download from source: Github, `name` and `repo` must be provided, got {name_} and {repo_}." + ) + _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_) else: raise NotImplementedError( - f"Currently only download from provided URL in `url` or Github is implemented, got source: {source}." + f"Currently only download from provided URL in `url` or Github is implemented, got source: {source_}." ) def load( name: str, - is_ts_model: bool = False, + load_ts_module: bool = False, bundle_dir: Optional[PathLike] = None, source: str = "github", repo: Optional[str] = None, @@ -264,21 +230,20 @@ def load( **net_kwargs, ): """ - Load model weights or TorchScript module. If the weights file does not exist locally, it will be downloaded first. + Load model weights or TorchScript module of a bundle. + If the weights file does not exist locally, it will be downloaded first. The function can return weights, an instantiated network that loaded the weights, or a TorchScript module. Args: - name: can be a string that contains only the bundle name, or a string that consists with the - bundle name, a separator `#` and the weights name. - For example, `name="spleen#model.pt"` means the bundle name is `spleen` and the weights name is `model.pt`. - If the weights name is not contained, `model.pt` or `model.ts` (according to the argument `is_ts_model`) - will be used as the default weights name. - is_ts_model: a flag to specify if the weights file is a TorchScript module. - bundle_dir: the directory the weights will be loaded from. + name: bundle name. + load_ts_module: a flag to specify if loading the TorchScript module. + bundle_dir: the directory the weights/TorchScript module will be loaded from. Default is `bundle` subfolder under`torch.hub get_dir()`. source: the place that saved the bundle. If `source` is `github`, the bundle should be within the releases. - repo: the repo name. If the weights need to be downloaded first and `url` is `None`, it must be provided. + repo: the repo name. If the weights file does not exist locally and `url` is `None`, it must be provided. + If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`. + For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`. progress: whether to display a progress bar when downloading. device: target device of returned weights or module. net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. @@ -294,19 +259,13 @@ def load( raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") bundle_dir = Path(bundle_dir) - weights_name: str = "model.ts" if is_ts_model is True else "model.pt" - if len(name.split(ID_SEP_KEY)) == 2: - weights_name = name.split(ID_SEP_KEY)[1] - elif len(name.split(ID_SEP_KEY)) == 1: - name = name + "#" + weights_name - else: - raise ValueError(f"The format of `name` is wrong.\n{run.__doc__}") - model_file_path = os.path.join(bundle_dir, weights_name) + weights_name: str = "model.ts" if load_ts_module is True else "model.pt" + model_file_path = os.path.join(bundle_dir, name, weights_name) if not os.path.exists(model_file_path): download(name=name, bundle_dir=bundle_dir, source=source, repo=repo, progress=progress) # loading with `torch.jit.load` - if is_ts_model is True: + if load_ts_module is True: return torch.jit.load(model_file_path, map_location=device) # loading with `torch.load` model_dict = torch.load(model_file_path, map_location=device) diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index fc1dd29045..9a377c7cae 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -20,33 +20,34 @@ import monai.networks.nets as nets from monai.apps import check_hash -from monai.bundle import ConfigParser, download, load +from monai.bundle import ConfigParser, load from tests.utils import skip_if_downloading_fails, skip_if_quick, skip_if_windows TEST_CASE_1 = [ ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], "test_bundle", - "Project-MONAI/MONAI-extra-test-data", + "Project-MONAI/MONAI-extra-test-data/0.8.1", "a131d39a0af717af32d19e565b434928", ] TEST_CASE_2 = [ - "test_bundle#network.json", - "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/test_bundle/network.json", + ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], + "test_bundle", + "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle.zip", "a131d39a0af717af32d19e565b434928", ] TEST_CASE_3 = [ ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], "test_bundle", - "Project-MONAI/MONAI-extra-test-data", + "Project-MONAI/MONAI-extra-test-data/0.8.1", "cuda" if torch.cuda.is_available() else "cpu", ] TEST_CASE_4 = [ ["test_output.pt", "test_input.pt"], "test_bundle", - "Project-MONAI/MONAI-extra-test-data", + "Project-MONAI/MONAI-extra-test-data/0.8.1", "cuda" if torch.cuda.is_available() else "cpu", ] @@ -59,20 +60,18 @@ def test_download_bundle(self, bundle_files, bundle_name, repo, hash_val): with skip_if_downloading_fails(): # download a whole bundle from github releases with tempfile.TemporaryDirectory() as tempdir: - bundle_dir = os.path.join(tempdir, "test_bundle") cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--name", bundle_name, "--source", "github"] - cmd += ["--bundle_dir", bundle_dir, "--repo", repo, "--progress", "False"] + cmd += ["--bundle_dir", tempdir, "--repo", repo, "--progress", "False"] subprocess.check_call(cmd) for file in bundle_files: - file_path = os.path.join(bundle_dir, file) + file_path = os.path.join(tempdir, bundle_name, file) self.assertTrue(os.path.exists(file_path)) - # check the md5 hash of the json file - if file == bundle_files[2]: + if file == "network.json": self.assertTrue(check_hash(filepath=file_path, val=hash_val)) @parameterized.expand([TEST_CASE_2]) @skip_if_quick - def test_url_download_bundle(self, bundle_name, url, hash_val): + def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val): with skip_if_downloading_fails(): # download a single file from url, also use `args_file` with tempfile.TemporaryDirectory() as tempdir: @@ -83,9 +82,11 @@ def test_url_download_bundle(self, bundle_name, url, hash_val): cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file] cmd += ["--url", url] subprocess.check_call(cmd) - file_path = os.path.join(tempdir, "network.json") - self.assertTrue(os.path.exists(file_path)) - self.assertTrue(check_hash(filepath=file_path, val=hash_val)) + for file in bundle_files: + file_path = os.path.join(tempdir, bundle_name, file) + self.assertTrue(os.path.exists(file_path)) + if file == "network.json": + self.assertTrue(check_hash(filepath=file_path, val=hash_val)) class TestLoad(unittest.TestCase): @@ -95,15 +96,11 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device): with skip_if_downloading_fails(): # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: - # download bundle - download(name=bundle_name, repo=repo, bundle_dir=tempdir, progress=False) - - # load weights only - weights_name = bundle_name + "#" + bundle_files[0] - weights = load(name=weights_name, bundle_dir=tempdir, progress=False, device=device) + # load weights + weights = load(name=bundle_name, bundle_dir=tempdir, repo=repo, progress=False, device=device) # prepare network - with open(os.path.join(tempdir, bundle_files[2])) as f: + with open(os.path.join(tempdir, bundle_name, bundle_files[2])) as f: net_args = json.load(f)["network_def"] model_name = net_args["_target_"] del net_args["_target_"] @@ -113,19 +110,15 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device): model.eval() # prepare data and test - input_tensor = torch.load(os.path.join(tempdir, bundle_files[4]), map_location=device) + input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[4]), map_location=device) output = model.forward(input_tensor) - expected_output = torch.load(os.path.join(tempdir, bundle_files[3]), map_location=device) + expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[3]), map_location=device) torch.testing.assert_allclose(output, expected_output) - # load instantiated model directly and test + # load instantiated model directly and test, since the bundle has been downloaded, + # there is no need to input `repo` model_2 = load( - name=weights_name, - bundle_dir=tempdir, - progress=False, - device=device, - net_name=model_name, - **net_args, + name=bundle_name, bundle_dir=tempdir, progress=False, device=device, net_name=model_name, **net_args ) model_2.eval() output_2 = model_2.forward(input_tensor) @@ -135,21 +128,17 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device): @skip_if_quick def test_load_ts_module(self, bundle_files, bundle_name, repo, device): with skip_if_downloading_fails(): - # load ts module, the module name is not included + # load ts module with tempfile.TemporaryDirectory() as tempdir: # load ts module model_ts = load( - name=bundle_name, is_ts_model=True, bundle_dir=tempdir, repo=repo, progress=False, device=device + name=bundle_name, load_ts_module=True, bundle_dir=tempdir, repo=repo, progress=False, device=device ) - # download input and output tensors - for file in bundle_files: - download_name = bundle_name + "#" + file - download(name=download_name, repo=repo, bundle_dir=tempdir, progress=False) # prepare and test - input_tensor = torch.load(os.path.join(tempdir, bundle_files[1]), map_location=device) + input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[1]), map_location=device) output = model_ts.forward(input_tensor) - expected_output = torch.load(os.path.join(tempdir, bundle_files[0]), map_location=device) + expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[0]), map_location=device) torch.testing.assert_allclose(output, expected_output) From 415f5291d6d92737c0a5c6fc91ea37d558938c4c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 22 Apr 2022 16:11:42 +0800 Subject: [PATCH 17/20] [DLMED] restore Exception for test Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 3ee86f83ad..8bf0b56977 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -30,7 +30,7 @@ from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import validate, _ = optional_import("jsonschema", name="validate") -exceptions, _ = optional_import("jsonschema", name="exceptions") +ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") requests_get, has_requests = optional_import("requests", name="get") @@ -413,13 +413,10 @@ def verify_metadata( try: # the rest key-values in the _args are for `validate` API validate(instance=metadata, schema=schema, **_args) - except exceptions.ValidationError as e: + except ValidationError as e: # as the error message is very long, only extract the key information logger.info(re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`.") return - except exceptions.SchemaError as e: - logger.info(str(e)) - return logger.info("metadata is verified with no error.") From dc5d7b4949677b92f937870dd2be8112245cc2ac Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 22 Apr 2022 17:37:17 +0800 Subject: [PATCH 18/20] update ts features Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 62 +++++++++++++++++++---------------- tests/test_bundle_download.py | 33 +++++++++++++++---- 2 files changed, 61 insertions(+), 34 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 3ee86f83ad..20f13b9d5b 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -25,12 +25,12 @@ from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser from monai.config import IgniteInfo, PathLike -from monai.data import save_net_with_metadata +from monai.data import load_net_with_metadata, save_net_with_metadata from monai.networks import convert_to_torchscript, copy_model_state from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import validate, _ = optional_import("jsonschema", name="validate") -exceptions, _ = optional_import("jsonschema", name="exceptions") +ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") requests_get, has_requests = optional_import("requests", name="get") @@ -136,6 +136,16 @@ def _download_from_github(repo: str, download_path: Path, filename: str, progres extractall(filepath=filepath, output_dir=download_path, has_base=True) +def _process_bundle_dir(bundle_dir: Optional[PathLike] = None): + if bundle_dir is None: + get_dir, has_home = optional_import("torch.hub", name="get_dir") + if has_home: + bundle_dir = Path(get_dir()) / "bundle" + else: + raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") + return Path(bundle_dir) + + def download( name: Optional[str] = None, bundle_dir: Optional[PathLike] = None, @@ -191,13 +201,7 @@ def download( _args, name=None, bundle_dir=None, source="github", repo=None, url=None, progress=True ) - if bundle_dir_ is None: - get_dir, has_home = optional_import("torch.hub", name="get_dir") - if has_home: - bundle_dir_ = Path(get_dir()) / "bundle" - else: - raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") - bundle_dir_ = Path(bundle_dir_) + bundle_dir_ = _process_bundle_dir(bundle_dir_) if url_ is not None: if name is not None: @@ -225,12 +229,16 @@ def load( source: str = "github", repo: Optional[str] = None, progress: bool = True, - device: str = "cpu", + device: Optional[str] = None, + model_file=None, + config_files: Sequence[str] = (), net_name: Optional[str] = None, **net_kwargs, ): """ Load model weights or TorchScript module of a bundle. + If loading a TorchScript module, the corresponding metadata dict, and extra files dict will be returned (please + check `monai.data.load_net_with_metadata` for more details). If the weights file does not exist locally, it will be downloaded first. The function can return weights, an instantiated network that loaded the weights, or a TorchScript module. @@ -245,30 +253,31 @@ def load( If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`. For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`. progress: whether to display a progress bar when downloading. - device: target device of returned weights or module. + device: target device of returned weights or module, if `None`, prefer to "cuda" if existing. + model_file: the relative path of the model weights or TorchScript module within bundle. + If `None`, "models/model.pt" or "models/model.ts" will be used. + config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module, + see `_extra_files` in `torch.jit.load` for more details. net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. This argument only works when loading weights. net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`. """ - if bundle_dir is None: - get_dir, has_home = optional_import("torch.hub", name="get_dir") - if has_home: - bundle_dir = Path(get_dir()) / "bundle" - else: - raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") - bundle_dir = Path(bundle_dir) + bundle_dir_ = _process_bundle_dir(bundle_dir) - weights_name: str = "model.ts" if load_ts_module is True else "model.pt" - model_file_path = os.path.join(bundle_dir, name, weights_name) - if not os.path.exists(model_file_path): - download(name=name, bundle_dir=bundle_dir, source=source, repo=repo, progress=progress) + if model_file is None: + model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt") + full_path = os.path.join(bundle_dir_, name, model_file) + if not os.path.exists(full_path): + download(name=name, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress) + if device is None: + device = "cuda:0" if is_available() else "cpu" # loading with `torch.jit.load` if load_ts_module is True: - return torch.jit.load(model_file_path, map_location=device) + return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) # loading with `torch.load` - model_dict = torch.load(model_file_path, map_location=device) + model_dict = torch.load(full_path, map_location=torch.device(device)) if net_name is None: return model_dict @@ -413,13 +422,10 @@ def verify_metadata( try: # the rest key-values in the _args are for `validate` API validate(instance=metadata, schema=schema, **_args) - except exceptions.ValidationError as e: + except ValidationError as e: # as the error message is very long, only extract the key information logger.info(re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`.") return - except exceptions.SchemaError as e: - logger.info(str(e)) - return logger.info("metadata is verified with no error.") diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 9a377c7cae..90510d0baa 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -42,6 +42,7 @@ "test_bundle", "Project-MONAI/MONAI-extra-test-data/0.8.1", "cuda" if torch.cuda.is_available() else "cpu", + "model.pt", ] TEST_CASE_4 = [ @@ -49,6 +50,7 @@ "test_bundle", "Project-MONAI/MONAI-extra-test-data/0.8.1", "cuda" if torch.cuda.is_available() else "cpu", + "model.ts", ] @@ -92,12 +94,19 @@ def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val): class TestLoad(unittest.TestCase): @parameterized.expand([TEST_CASE_3]) @skip_if_quick - def test_load_weights(self, bundle_files, bundle_name, repo, device): + def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file): with skip_if_downloading_fails(): # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: # load weights - weights = load(name=bundle_name, bundle_dir=tempdir, repo=repo, progress=False, device=device) + weights = load( + name=bundle_name, + bundle_dir=tempdir, + repo=repo, + progress=False, + device=device, + model_file=model_file, + ) # prepare network with open(os.path.join(tempdir, bundle_name, bundle_files[2])) as f: @@ -118,7 +127,13 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device): # load instantiated model directly and test, since the bundle has been downloaded, # there is no need to input `repo` model_2 = load( - name=bundle_name, bundle_dir=tempdir, progress=False, device=device, net_name=model_name, **net_args + name=bundle_name, + bundle_dir=tempdir, + progress=False, + device=device, + model_file=model_file, + net_name=model_name, + **net_args, ) model_2.eval() output_2 = model_2.forward(input_tensor) @@ -126,14 +141,20 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device): @parameterized.expand([TEST_CASE_4]) @skip_if_quick - def test_load_ts_module(self, bundle_files, bundle_name, repo, device): + def test_load_ts_module(self, bundle_files, bundle_name, repo, device, model_file): with skip_if_downloading_fails(): # load ts module with tempfile.TemporaryDirectory() as tempdir: # load ts module model_ts = load( - name=bundle_name, load_ts_module=True, bundle_dir=tempdir, repo=repo, progress=False, device=device - ) + name=bundle_name, + load_ts_module=True, + bundle_dir=tempdir, + repo=repo, + model_file=model_file, + progress=False, + device=device, + )[0] # prepare and test input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[1]), map_location=device) From c083b48b1550810789bf00a342e10f274cbe2bfc Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 22 Apr 2022 19:23:24 +0800 Subject: [PATCH 19/20] add config_files test case Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 18 +++++++++++------- tests/test_bundle_download.py | 18 ++++++++++++------ 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 20f13b9d5b..eef3a4e61d 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -224,26 +224,24 @@ def download( def load( name: str, + model_file: Optional[str] = None, load_ts_module: bool = False, bundle_dir: Optional[PathLike] = None, source: str = "github", repo: Optional[str] = None, progress: bool = True, device: Optional[str] = None, - model_file=None, config_files: Sequence[str] = (), net_name: Optional[str] = None, **net_kwargs, ): """ Load model weights or TorchScript module of a bundle. - If loading a TorchScript module, the corresponding metadata dict, and extra files dict will be returned (please - check `monai.data.load_net_with_metadata` for more details). - If the weights file does not exist locally, it will be downloaded first. - The function can return weights, an instantiated network that loaded the weights, or a TorchScript module. Args: name: bundle name. + model_file: the relative path of the model weights or TorchScript module within bundle. + If `None`, "models/model.pt" or "models/model.ts" will be used. load_ts_module: a flag to specify if loading the TorchScript module. bundle_dir: the directory the weights/TorchScript module will be loaded from. Default is `bundle` subfolder under`torch.hub get_dir()`. @@ -254,14 +252,20 @@ def load( For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`. progress: whether to display a progress bar when downloading. device: target device of returned weights or module, if `None`, prefer to "cuda" if existing. - model_file: the relative path of the model weights or TorchScript module within bundle. - If `None`, "models/model.pt" or "models/model.ts" will be used. config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module, see `_extra_files` in `torch.jit.load` for more details. net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. This argument only works when loading weights. net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`. + Returns: + 1. If `load_ts_module` is `False` and `net_name` is `None`, return model weights. + 2. If `load_ts_module` is `False` and `net_name` is not `None`, + return an instantiated network that loaded the weights. + 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module, + the corresponding metadata dict, and extra files dict. + please check `monai.data.load_net_with_metadata` for more details. + """ bundle_dir_ = _process_bundle_dir(bundle_dir) diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 90510d0baa..921399bc54 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -101,11 +101,11 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) # load weights weights = load( name=bundle_name, + model_file=model_file, bundle_dir=tempdir, repo=repo, progress=False, device=device, - model_file=model_file, ) # prepare network @@ -128,10 +128,10 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) # there is no need to input `repo` model_2 = load( name=bundle_name, + model_file=model_file, bundle_dir=tempdir, progress=False, device=device, - model_file=model_file, net_name=model_name, **net_args, ) @@ -146,21 +146,27 @@ def test_load_ts_module(self, bundle_files, bundle_name, repo, device, model_fil # load ts module with tempfile.TemporaryDirectory() as tempdir: # load ts module - model_ts = load( + model_ts, metadata, extra_file_dict = load( name=bundle_name, + model_file=model_file, load_ts_module=True, bundle_dir=tempdir, repo=repo, - model_file=model_file, progress=False, device=device, - )[0] + config_files=("test_config.txt",), + ) - # prepare and test + # prepare and test ts input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[1]), map_location=device) output = model_ts.forward(input_tensor) expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[0]), map_location=device) torch.testing.assert_allclose(output, expected_output) + # test metadata + self.assertTrue(metadata["foo"] == [1, 2]) + self.assertTrue(metadata["bar"] == "string") + # test extra_file_dict + self.assertTrue("test_config.txt" in extra_file_dict.keys()) if __name__ == "__main__": From 062a8006a105b94eed80a2be6e098e083c55a1fa Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 22 Apr 2022 22:53:48 +0800 Subject: [PATCH 20/20] enhance docstring example for args_file Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index eef3a4e61d..33affcf31b 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -172,8 +172,11 @@ def download( python -m monai.bundle download --name "bundle_name" --url # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. - # Other args still can override the default args at runtime: - python -m monai.bundle download --args_file "/workspace/data/args.json" --source "github" + # Other args still can override the default args at runtime. + # The content of the JSON / YAML file is a dictionary. For example: + # {"name": "spleen", "bundle_dir": "download", "source": ""} + # then do the following command for downloading: + python -m monai.bundle download --args_file "args.json" --source "github" Args: name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.