diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 297409cd7e..a28db04091 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -36,6 +36,8 @@ Model Bundle `Scripts` --------- .. autofunction:: ckpt_export +.. autofunction:: download +.. autofunction:: load .. autofunction:: run .. autofunction:: verify_metadata .. autofunction:: verify_net_in_out diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index d6a452b5a4..f30ee9c40c 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -12,5 +12,5 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser from .reference_resolver import ReferenceResolver -from .scripts import ckpt_export, run, verify_metadata, verify_net_in_out +from .scripts import ckpt_export, download, load, run, verify_metadata, verify_net_in_out from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index d77b396e79..3e3534ef74 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -10,7 +10,7 @@ # limitations under the License. -from monai.bundle.scripts import ckpt_export, run, verify_metadata, verify_net_in_out +from monai.bundle.scripts import ckpt_export, download, run, verify_metadata, verify_net_in_out if __name__ == "__main__": from monai.utils import optional_import diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index b741e40e8d..33affcf31b 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -11,24 +11,28 @@ import ast import json +import os import pprint import re from logging.config import fileConfig +from pathlib import Path from typing import Dict, Optional, Sequence, Tuple, Union import torch from torch.cuda import is_available -from monai.apps.utils import download_url, get_logger +from monai.apps.utils import _basename, download_url, extractall, get_logger +from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser from monai.config import IgniteInfo, PathLike -from monai.data import save_net_with_metadata +from monai.data import load_net_with_metadata, save_net_with_metadata from monai.networks import convert_to_torchscript, copy_model_state from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") +requests_get, has_requests = optional_import("requests", name="get") logger = get_logger(module_name=__name__) @@ -116,6 +120,182 @@ def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int return tuple(ret) +def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filename: str): + return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}" + + +def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True): + if len(repo.split("/")) != 3: + raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name/release_tag`.") + repo_owner, repo_name, tag_name = repo.split("/") + if ".zip" not in filename: + filename += ".zip" + url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename) + filepath = download_path / f"{filename}" + download_url(url=url, filepath=filepath, hash_val=None, progress=progress) + extractall(filepath=filepath, output_dir=download_path, has_base=True) + + +def _process_bundle_dir(bundle_dir: Optional[PathLike] = None): + if bundle_dir is None: + get_dir, has_home = optional_import("torch.hub", name="get_dir") + if has_home: + bundle_dir = Path(get_dir()) / "bundle" + else: + raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") + return Path(bundle_dir) + + +def download( + name: Optional[str] = None, + bundle_dir: Optional[PathLike] = None, + source: str = "github", + repo: Optional[str] = None, + url: Optional[str] = None, + progress: bool = True, + args_file: Optional[str] = None, +): + """ + download bundle from the specified source or url. The bundle should be a zip file and it + will be extracted after downloading. + This function refers to: + https://pytorch.org/docs/stable/_modules/torch/hub.html + + Typical usage examples: + + .. code-block:: bash + + # Execute this module as a CLI entry, and download bundle: + python -m monai.bundle download --name "bundle_name" --source "github" --repo "repo_owner/repo_name/release_tag" + + # Execute this module as a CLI entry, and download bundle via URL: + python -m monai.bundle download --name "bundle_name" --url + + # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. + # Other args still can override the default args at runtime. + # The content of the JSON / YAML file is a dictionary. For example: + # {"name": "spleen", "bundle_dir": "download", "source": ""} + # then do the following command for downloading: + python -m monai.bundle download --args_file "args.json" --source "github" + + Args: + name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`. + bundle_dir: target directory to store the downloaded data. + Default is `bundle` subfolder under`torch.hub get_dir()`. + source: place that saved the bundle. + If `source` is `github`, the bundle should be within the releases. + repo: repo name. If `None` and `url` is `None`, it must be provided in `args_file`. + If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`. + For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`. + url: url to download the data. If not `None`, data will be downloaded directly + and `source` will not be checked. + If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`. + progress: whether to display a progress bar. + args_file: a JSON or YAML file to provide default values for all the args in this function. + so that the command line inputs can be simplified. + + """ + _args = _update_args( + args=args_file, name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress + ) + + _log_input_summary(tag="download", args=_args) + name_, bundle_dir_, source_, repo_, url_, progress_ = _pop_args( + _args, name=None, bundle_dir=None, source="github", repo=None, url=None, progress=True + ) + + bundle_dir_ = _process_bundle_dir(bundle_dir_) + + if url_ is not None: + if name is not None: + filepath = bundle_dir_ / f"{name}.zip" + else: + filepath = bundle_dir_ / f"{_basename(url_)}" + download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_) + extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True) + elif source_ == "github": + if name_ is None or repo_ is None: + raise ValueError( + f"To download from source: Github, `name` and `repo` must be provided, got {name_} and {repo_}." + ) + _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_) + else: + raise NotImplementedError( + f"Currently only download from provided URL in `url` or Github is implemented, got source: {source_}." + ) + + +def load( + name: str, + model_file: Optional[str] = None, + load_ts_module: bool = False, + bundle_dir: Optional[PathLike] = None, + source: str = "github", + repo: Optional[str] = None, + progress: bool = True, + device: Optional[str] = None, + config_files: Sequence[str] = (), + net_name: Optional[str] = None, + **net_kwargs, +): + """ + Load model weights or TorchScript module of a bundle. + + Args: + name: bundle name. + model_file: the relative path of the model weights or TorchScript module within bundle. + If `None`, "models/model.pt" or "models/model.ts" will be used. + load_ts_module: a flag to specify if loading the TorchScript module. + bundle_dir: the directory the weights/TorchScript module will be loaded from. + Default is `bundle` subfolder under`torch.hub get_dir()`. + source: the place that saved the bundle. + If `source` is `github`, the bundle should be within the releases. + repo: the repo name. If the weights file does not exist locally and `url` is `None`, it must be provided. + If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`. + For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`. + progress: whether to display a progress bar when downloading. + device: target device of returned weights or module, if `None`, prefer to "cuda" if existing. + config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module, + see `_extra_files` in `torch.jit.load` for more details. + net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. + This argument only works when loading weights. + net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`. + + Returns: + 1. If `load_ts_module` is `False` and `net_name` is `None`, return model weights. + 2. If `load_ts_module` is `False` and `net_name` is not `None`, + return an instantiated network that loaded the weights. + 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module, + the corresponding metadata dict, and extra files dict. + please check `monai.data.load_net_with_metadata` for more details. + + """ + bundle_dir_ = _process_bundle_dir(bundle_dir) + + if model_file is None: + model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt") + full_path = os.path.join(bundle_dir_, name, model_file) + if not os.path.exists(full_path): + download(name=name, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress) + + if device is None: + device = "cuda:0" if is_available() else "cpu" + # loading with `torch.jit.load` + if load_ts_module is True: + return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) + # loading with `torch.load` + model_dict = torch.load(full_path, map_location=torch.device(device)) + + if net_name is None: + return model_dict + net_kwargs["_target_"] = net_name + configer = ConfigComponent(config=net_kwargs) + model = configer.instantiate() + model.to(device) # type: ignore + model.load_state_dict(model_dict) # type: ignore + return model + + def run( runner_id: Optional[str] = None, meta_file: Optional[Union[str, Sequence[str]]] = None, diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py new file mode 100644 index 0000000000..921399bc54 --- /dev/null +++ b/tests/test_bundle_download.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import subprocess +import tempfile +import unittest + +import torch +from parameterized import parameterized + +import monai.networks.nets as nets +from monai.apps import check_hash +from monai.bundle import ConfigParser, load +from tests.utils import skip_if_downloading_fails, skip_if_quick, skip_if_windows + +TEST_CASE_1 = [ + ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], + "test_bundle", + "Project-MONAI/MONAI-extra-test-data/0.8.1", + "a131d39a0af717af32d19e565b434928", +] + +TEST_CASE_2 = [ + ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], + "test_bundle", + "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle.zip", + "a131d39a0af717af32d19e565b434928", +] + +TEST_CASE_3 = [ + ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"], + "test_bundle", + "Project-MONAI/MONAI-extra-test-data/0.8.1", + "cuda" if torch.cuda.is_available() else "cpu", + "model.pt", +] + +TEST_CASE_4 = [ + ["test_output.pt", "test_input.pt"], + "test_bundle", + "Project-MONAI/MONAI-extra-test-data/0.8.1", + "cuda" if torch.cuda.is_available() else "cpu", + "model.ts", +] + + +@skip_if_windows +class TestDownload(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + @skip_if_quick + def test_download_bundle(self, bundle_files, bundle_name, repo, hash_val): + with skip_if_downloading_fails(): + # download a whole bundle from github releases + with tempfile.TemporaryDirectory() as tempdir: + cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--name", bundle_name, "--source", "github"] + cmd += ["--bundle_dir", tempdir, "--repo", repo, "--progress", "False"] + subprocess.check_call(cmd) + for file in bundle_files: + file_path = os.path.join(tempdir, bundle_name, file) + self.assertTrue(os.path.exists(file_path)) + if file == "network.json": + self.assertTrue(check_hash(filepath=file_path, val=hash_val)) + + @parameterized.expand([TEST_CASE_2]) + @skip_if_quick + def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val): + with skip_if_downloading_fails(): + # download a single file from url, also use `args_file` + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"name": bundle_name, "bundle_dir": tempdir, "url": ""} + def_args_file = os.path.join(tempdir, "def_args.json") + parser = ConfigParser() + parser.export_config_file(config=def_args, filepath=def_args_file) + cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file] + cmd += ["--url", url] + subprocess.check_call(cmd) + for file in bundle_files: + file_path = os.path.join(tempdir, bundle_name, file) + self.assertTrue(os.path.exists(file_path)) + if file == "network.json": + self.assertTrue(check_hash(filepath=file_path, val=hash_val)) + + +class TestLoad(unittest.TestCase): + @parameterized.expand([TEST_CASE_3]) + @skip_if_quick + def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file): + with skip_if_downloading_fails(): + # download bundle, and load weights from the downloaded path + with tempfile.TemporaryDirectory() as tempdir: + # load weights + weights = load( + name=bundle_name, + model_file=model_file, + bundle_dir=tempdir, + repo=repo, + progress=False, + device=device, + ) + + # prepare network + with open(os.path.join(tempdir, bundle_name, bundle_files[2])) as f: + net_args = json.load(f)["network_def"] + model_name = net_args["_target_"] + del net_args["_target_"] + model = nets.__dict__[model_name](**net_args) + model.to(device) + model.load_state_dict(weights) + model.eval() + + # prepare data and test + input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[4]), map_location=device) + output = model.forward(input_tensor) + expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[3]), map_location=device) + torch.testing.assert_allclose(output, expected_output) + + # load instantiated model directly and test, since the bundle has been downloaded, + # there is no need to input `repo` + model_2 = load( + name=bundle_name, + model_file=model_file, + bundle_dir=tempdir, + progress=False, + device=device, + net_name=model_name, + **net_args, + ) + model_2.eval() + output_2 = model_2.forward(input_tensor) + torch.testing.assert_allclose(output_2, expected_output) + + @parameterized.expand([TEST_CASE_4]) + @skip_if_quick + def test_load_ts_module(self, bundle_files, bundle_name, repo, device, model_file): + with skip_if_downloading_fails(): + # load ts module + with tempfile.TemporaryDirectory() as tempdir: + # load ts module + model_ts, metadata, extra_file_dict = load( + name=bundle_name, + model_file=model_file, + load_ts_module=True, + bundle_dir=tempdir, + repo=repo, + progress=False, + device=device, + config_files=("test_config.txt",), + ) + + # prepare and test ts + input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[1]), map_location=device) + output = model_ts.forward(input_tensor) + expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[0]), map_location=device) + torch.testing.assert_allclose(output, expected_output) + # test metadata + self.assertTrue(metadata["foo"] == [1, 2]) + self.assertTrue(metadata["bar"] == "string") + # test extra_file_dict + self.assertTrue("test_config.txt" in extra_file_dict.keys()) + + +if __name__ == "__main__": + unittest.main()