Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9a6a80a
draft download
yiheng-wang-nv Apr 8, 2022
3b6d2cd
update bundle download
yiheng-wang-nv Apr 12, 2022
899dbd6
add url and load
yiheng-wang-nv Apr 14, 2022
b9fc7dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2022
981a990
rename args and remove a few places
yiheng-wang-nv Apr 14, 2022
24ea321
Merge branch 'add-bundle-download' of github.com:yiheng-wang-nv/MONAI…
yiheng-wang-nv Apr 14, 2022
2769ad2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2022
31e9280
fix flake8 issue
yiheng-wang-nv Apr 14, 2022
0d11dc9
Merge branch 'dev' into add-bundle-download
yiheng-wang-nv Apr 14, 2022
d37279a
Merge branch 'add-bundle-download' of github.com:yiheng-wang-nv/MONAI…
yiheng-wang-nv Apr 14, 2022
95227b1
enhance with reviews
yiheng-wang-nv Apr 18, 2022
adf187a
Merge branch 'dev' into add-bundle-download
yiheng-wang-nv Apr 18, 2022
0e53035
add instantiate for load
yiheng-wang-nv Apr 18, 2022
5cad064
fix black error
yiheng-wang-nv Apr 18, 2022
3ece567
add unittest
yiheng-wang-nv Apr 19, 2022
15dc4b6
add load to docs
yiheng-wang-nv Apr 19, 2022
12d400a
add skip
yiheng-wang-nv Apr 19, 2022
78bbfa6
add schemaerror
yiheng-wang-nv Apr 19, 2022
780ecb7
fix partial places
yiheng-wang-nv Apr 20, 2022
2b936d6
Merge branch 'dev' into add-bundle-download
yiheng-wang-nv Apr 20, 2022
627ae48
download zip bundle
yiheng-wang-nv Apr 21, 2022
0e6062d
Merge branch 'dev' into add-bundle-download
yiheng-wang-nv Apr 21, 2022
415f529
[DLMED] restore Exception for test
Nic-Ma Apr 22, 2022
dc5d7b4
update ts features
yiheng-wang-nv Apr 22, 2022
cb906fd
Merge branch 'add-bundle-download' of github.com:yiheng-wang-nv/MONAI…
yiheng-wang-nv Apr 22, 2022
8927d2b
Merge branch 'dev' into add-bundle-download
yiheng-wang-nv Apr 22, 2022
c083b48
add config_files test case
yiheng-wang-nv Apr 22, 2022
284be7e
Merge branch 'add-bundle-download' of github.com:yiheng-wang-nv/MONAI…
yiheng-wang-nv Apr 22, 2022
062a800
enhance docstring example for args_file
yiheng-wang-nv Apr 22, 2022
c8cc140
Merge branch 'dev' into add-bundle-download
yiheng-wang-nv Apr 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/bundle.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ Model Bundle
`Scripts`
---------
.. autofunction:: ckpt_export
.. autofunction:: download
.. autofunction:: load
.. autofunction:: run
.. autofunction:: verify_metadata
.. autofunction:: verify_net_in_out
2 changes: 1 addition & 1 deletion monai/bundle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
from .config_parser import ConfigParser
from .reference_resolver import ReferenceResolver
from .scripts import ckpt_export, run, verify_metadata, verify_net_in_out
from .scripts import ckpt_export, download, load, run, verify_metadata, verify_net_in_out
from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
2 changes: 1 addition & 1 deletion monai/bundle/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.


from monai.bundle.scripts import ckpt_export, run, verify_metadata, verify_net_in_out
from monai.bundle.scripts import ckpt_export, download, run, verify_metadata, verify_net_in_out

if __name__ == "__main__":
from monai.utils import optional_import
Expand Down
184 changes: 182 additions & 2 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,28 @@

import ast
import json
import os
import pprint
import re
from logging.config import fileConfig
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
from torch.cuda import is_available

from monai.apps.utils import download_url, get_logger
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
from monai.config import IgniteInfo, PathLike
from monai.data import save_net_with_metadata
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.networks import convert_to_torchscript, copy_model_state
from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import

validate, _ = optional_import("jsonschema", name="validate")
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
requests_get, has_requests = optional_import("requests", name="get")

logger = get_logger(module_name=__name__)

Expand Down Expand Up @@ -116,6 +120,182 @@ def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int
return tuple(ret)


def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filename: str):
return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}"


def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True):
if len(repo.split("/")) != 3:
raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name/release_tag`.")
repo_owner, repo_name, tag_name = repo.split("/")
if ".zip" not in filename:
filename += ".zip"
url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename)
filepath = download_path / f"{filename}"
download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
extractall(filepath=filepath, output_dir=download_path, has_base=True)


def _process_bundle_dir(bundle_dir: Optional[PathLike] = None):
if bundle_dir is None:
get_dir, has_home = optional_import("torch.hub", name="get_dir")
if has_home:
bundle_dir = Path(get_dir()) / "bundle"
else:
raise ValueError("bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?")
return Path(bundle_dir)


def download(
name: Optional[str] = None,
bundle_dir: Optional[PathLike] = None,
source: str = "github",
repo: Optional[str] = None,
url: Optional[str] = None,
progress: bool = True,
args_file: Optional[str] = None,
):
"""
download bundle from the specified source or url. The bundle should be a zip file and it
will be extracted after downloading.
This function refers to:
https://pytorch.org/docs/stable/_modules/torch/hub.html

Typical usage examples:

.. code-block:: bash

# Execute this module as a CLI entry, and download bundle:
python -m monai.bundle download --name "bundle_name" --source "github" --repo "repo_owner/repo_name/release_tag"

# Execute this module as a CLI entry, and download bundle via URL:
python -m monai.bundle download --name "bundle_name" --url <url>

# Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
# Other args still can override the default args at runtime.
# The content of the JSON / YAML file is a dictionary. For example:
# {"name": "spleen", "bundle_dir": "download", "source": ""}
# then do the following command for downloading:
python -m monai.bundle download --args_file "args.json" --source "github"

Args:
name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
bundle_dir: target directory to store the downloaded data.
Default is `bundle` subfolder under`torch.hub get_dir()`.
source: place that saved the bundle.
If `source` is `github`, the bundle should be within the releases.
repo: repo name. If `None` and `url` is `None`, it must be provided in `args_file`.
If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`.
For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`.
url: url to download the data. If not `None`, data will be downloaded directly
and `source` will not be checked.
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
progress: whether to display a progress bar.
args_file: a JSON or YAML file to provide default values for all the args in this function.
so that the command line inputs can be simplified.

"""
_args = _update_args(
args=args_file, name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress
)

_log_input_summary(tag="download", args=_args)
name_, bundle_dir_, source_, repo_, url_, progress_ = _pop_args(
_args, name=None, bundle_dir=None, source="github", repo=None, url=None, progress=True
)

bundle_dir_ = _process_bundle_dir(bundle_dir_)

if url_ is not None:
if name is not None:
filepath = bundle_dir_ / f"{name}.zip"
else:
filepath = bundle_dir_ / f"{_basename(url_)}"
download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
elif source_ == "github":
if name_ is None or repo_ is None:
raise ValueError(
f"To download from source: Github, `name` and `repo` must be provided, got {name_} and {repo_}."
)
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
else:
raise NotImplementedError(
f"Currently only download from provided URL in `url` or Github is implemented, got source: {source_}."
)


def load(
name: str,
model_file: Optional[str] = None,
load_ts_module: bool = False,
bundle_dir: Optional[PathLike] = None,
source: str = "github",
repo: Optional[str] = None,
progress: bool = True,
device: Optional[str] = None,
config_files: Sequence[str] = (),
net_name: Optional[str] = None,
**net_kwargs,
):
"""
Load model weights or TorchScript module of a bundle.

Args:
name: bundle name.
model_file: the relative path of the model weights or TorchScript module within bundle.
If `None`, "models/model.pt" or "models/model.ts" will be used.
load_ts_module: a flag to specify if loading the TorchScript module.
bundle_dir: the directory the weights/TorchScript module will be loaded from.
Default is `bundle` subfolder under`torch.hub get_dir()`.
source: the place that saved the bundle.
If `source` is `github`, the bundle should be within the releases.
repo: the repo name. If the weights file does not exist locally and `url` is `None`, it must be provided.
If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`.
For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`.
progress: whether to display a progress bar when downloading.
device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module,
see `_extra_files` in `torch.jit.load` for more details.
net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights.
This argument only works when loading weights.
net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.

Returns:
1. If `load_ts_module` is `False` and `net_name` is `None`, return model weights.
2. If `load_ts_module` is `False` and `net_name` is not `None`,
return an instantiated network that loaded the weights.
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
the corresponding metadata dict, and extra files dict.
please check `monai.data.load_net_with_metadata` for more details.

"""
bundle_dir_ = _process_bundle_dir(bundle_dir)

if model_file is None:
model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt")
full_path = os.path.join(bundle_dir_, name, model_file)
if not os.path.exists(full_path):
download(name=name, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress)

if device is None:
device = "cuda:0" if is_available() else "cpu"
# loading with `torch.jit.load`
if load_ts_module is True:
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
# loading with `torch.load`
model_dict = torch.load(full_path, map_location=torch.device(device))

if net_name is None:
return model_dict
net_kwargs["_target_"] = net_name
configer = ConfigComponent(config=net_kwargs)
model = configer.instantiate()
model.to(device) # type: ignore
model.load_state_dict(model_dict) # type: ignore
return model


def run(
runner_id: Optional[str] = None,
meta_file: Optional[Union[str, Sequence[str]]] = None,
Expand Down
Loading