Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
42a45e0
Merge pull request #19 from Project-MONAI/master
Nic-Ma Feb 1, 2021
cd16a13
Merge pull request #32 from Project-MONAI/master
Nic-Ma Feb 24, 2021
6f87afd
Merge pull request #180 from Project-MONAI/dev
Nic-Ma Jul 22, 2021
f398298
Merge pull request #214 from Project-MONAI/dev
Nic-Ma Sep 8, 2021
c86d6b6
Merge pull request #388 from Project-MONAI/dev
Nic-Ma Mar 17, 2022
59a7524
[DLMED] add export script
Nic-Ma Mar 17, 2022
87e87ca
Merge branch 'dev' into 3482-add-export-script
Nic-Ma Mar 17, 2022
8d54283
[DLMED] add base unit test
Nic-Ma Mar 17, 2022
ddf1c8d
Merge branch 'dev' into 3482-add-export-script
Nic-Ma Mar 17, 2022
6d4b7e4
Merge branch 'dev' into 3482-add-export-script
Nic-Ma Mar 18, 2022
b750b73
[DLMED] add custom config item types and reference resolver
Nic-Ma Mar 18, 2022
36d8469
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2022
1fbd385
[DLMED] enhance doc-string
Nic-Ma Mar 18, 2022
81ac67e
Merge branch '3482-add-export-script' of https://github.com/Nic-Ma/MO…
Nic-Ma Mar 18, 2022
5c23353
[DLMED] fix typo
Nic-Ma Mar 18, 2022
b5b9f20
[DLMED] add logging
Nic-Ma Mar 18, 2022
a90d633
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2022
8ababd8
Merge branch 'dev' into 3482-add-export-script
Nic-Ma Mar 21, 2022
e479b14
Merge branch 'dev' into 3482-add-export-script
wyli Mar 21, 2022
efd0b10
[DLMED] remove customized configitem and resolver
Nic-Ma Mar 22, 2022
5498c53
[DLMED] remove assert ret
Nic-Ma Mar 22, 2022
c87e458
[DLMED] update to ckpt_export
Nic-Ma Mar 22, 2022
2344abb
[DLMED] update to .ts
Nic-Ma Mar 22, 2022
a76843c
[DLMED] update min test
Nic-Ma Mar 22, 2022
db98006
[DLMED] simplify args
Nic-Ma Mar 22, 2022
1480009
Merge branch 'dev' into 3482-add-export-script
Nic-Ma Mar 22, 2022
700a9d9
Merge branch 'dev' into 3482-add-export-script
Nic-Ma Mar 22, 2022
f352bd1
fixes integration tests
wyli Mar 23, 2022
a157472
Merge branch 'dev' into 3482-add-export-script
wyli Mar 23, 2022
26f3ec4
[DLMED] change to coverage
Nic-Ma Mar 23, 2022
fce2bd7
update integration tests
wyli Mar 23, 2022
08ca004
Merge branch 'dev' into 3482-add-export-script
wyli Mar 23, 2022
1db8f00
[DLMED] fix blossom
Nic-Ma Mar 24, 2022
84b2085
Merge branch 'dev' into 3482-add-export-script
Nic-Ma Mar 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/bundle.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Model Bundle

`Scripts`
---------
.. autofunction:: ckpt_export
.. autofunction:: run
.. autofunction:: verify_metadata
.. autofunction:: verify_net_in_out
2 changes: 1 addition & 1 deletion monai/bundle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
from .config_parser import ConfigParser
from .reference_resolver import ReferenceResolver
from .scripts import run, verify_metadata, verify_net_in_out
from .scripts import ckpt_export, run, verify_metadata, verify_net_in_out
from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
2 changes: 1 addition & 1 deletion monai/bundle/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.


from monai.bundle.scripts import run, verify_metadata, verify_net_in_out
from monai.bundle.scripts import ckpt_export, run, verify_metadata, verify_net_in_out

if __name__ == "__main__":
from monai.utils import optional_import
Expand Down
147 changes: 117 additions & 30 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,24 @@
# limitations under the License.

import ast
import json
import pprint
import re
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
from torch.cuda import is_available

from monai.apps.utils import download_url, get_logger
from monai.bundle.config_parser import ConfigParser
from monai.config import PathLike
from monai.utils import check_parent_dir, get_equivalent_dtype, optional_import
from monai.config import IgniteInfo, PathLike
from monai.data import save_net_with_metadata
from monai.networks import convert_to_torchscript, copy_model_state
from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import

validate, _ = optional_import("jsonschema", name="validate")
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")

logger = get_logger(module_name=__name__)

Expand Down Expand Up @@ -54,6 +59,14 @@ def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = Tr
return args_


def _pop_args(src: Dict, *args, **kwargs):
"""
Pop args from the `src` dictionary based on specified keys in `args` and (key, default value) pairs in `kwargs`.

"""
return tuple([src.pop(i) for i in args] + [src.pop(k, v) for k, v in kwargs.items()])


def _log_input_summary(tag, args: Dict):
logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---")
for name, val in args.items():
Expand Down Expand Up @@ -148,20 +161,22 @@ def run(
if "config_file" not in _args:
raise ValueError(f"`config_file` is required for 'monai.bundle run'.\n{run.__doc__}")
_log_input_summary(tag="run", args=_args)
config_file_, meta_file_, runner_id_ = _pop_args(_args, "config_file", meta_file=None, runner_id="")

parser = ConfigParser()
parser.read_config(f=_args.pop("config_file"))
if "meta_file" in _args:
parser.read_meta(f=_args.pop("meta_file"))
id = _args.pop("runner_id", "")
parser.read_config(f=config_file_)
if meta_file_ is not None:
parser.read_meta(f=meta_file_)

# the rest key-values in the _args are to override config content
for k, v in _args.items():
parser[k] = v

workflow = parser.get_parsed_content(id=id)
workflow = parser.get_parsed_content(id=runner_id_)
if not hasattr(workflow, "run"):
raise ValueError(f"The parsed workflow {type(workflow)} (id={id}) does not have a `run` method.\n{run.__doc__}")
raise ValueError(
f"The parsed workflow {type(workflow)} (id={runner_id_}) does not have a `run` method.\n{run.__doc__}"
)
return workflow.run()


Expand Down Expand Up @@ -203,22 +218,16 @@ def verify_metadata(
**kwargs,
)
_log_input_summary(tag="verify_metadata", args=_args)
filepath_, meta_file_, create_dir_, hash_val_, hash_type_ = _pop_args(
_args, "filepath", "meta_file", create_dir=True, hash_val=None, hash_type="md5"
)

filepath_ = _args.pop("filepath")
create_dir_ = _args.pop("create_dir", True)
check_parent_dir(path=filepath_, create_dir=create_dir_)

metadata = ConfigParser.load_config_files(files=_args.pop("meta_file"))
metadata = ConfigParser.load_config_files(files=meta_file_)
url = metadata.get("schema")
if url is None:
raise ValueError("must provide the `schema` field in the metadata for the URL of schema file.")
download_url(
url=url,
filepath=filepath_,
hash_val=_args.pop("hash_val", None),
hash_type=_args.pop("hash_type", "md5"),
progress=True,
)
download_url(url=url, filepath=filepath_, hash_val=hash_val_, hash_type=hash_type_, progress=True)
schema = ConfigParser.load_config_file(filepath=filepath_)

try:
Expand Down Expand Up @@ -262,8 +271,8 @@ def verify_net_in_out(
p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1.
n: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1.
any: specified size to generate fake data shape if dim of expected shape is "*", default to 1.
args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
`net_id` and override pairs. so that the command line inputs can be simplified.
args_file: a JSON or YAML file to provide default values for `net_id`, `meta_file`, `config_file`,
`device`, `p`, `n`, `any`, and override pairs, so that the command line inputs can be simplified.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.

Expand All @@ -281,22 +290,20 @@ def verify_net_in_out(
**override,
)
_log_input_summary(tag="verify_net_in_out", args=_args)
config_file_, meta_file_, net_id_, device_, p_, n_, any_ = _pop_args(
_args, "config_file", "meta_file", net_id="", device="cuda:0" if is_available() else "cpu", p=1, n=1, any=1
)

parser = ConfigParser()
parser.read_config(f=_args.pop("config_file"))
parser.read_meta(f=_args.pop("meta_file"))
id = _args.pop("net_id", "")
device_ = torch.device(_args.pop("device", "cuda:0" if torch.cuda.is_available() else "cpu"))
p = _args.pop("p", 1)
n = _args.pop("n", 1)
any = _args.pop("any", 1)
parser.read_config(f=config_file_)
parser.read_meta(f=meta_file_)

# the rest key-values in the _args are to override config content
for k, v in _args.items():
parser[k] = v

try:
key: str = id # mark the full id when KeyError
key: str = net_id_ # mark the full id when KeyError
net = parser.get_parsed_content(key).to(device_)
key = "_meta_#network_data_format#inputs#image#num_channels"
input_channels = parser[key]
Expand All @@ -313,11 +320,91 @@ def verify_net_in_out(

net.eval()
with torch.no_grad():
spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p, n=n, any=any) # type: ignore
spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p_, n=n_, any=any_) # type: ignore
test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device_)
output = net(test_data)
if output.shape[1] != output_channels:
raise ValueError(f"output channel number `{output.shape[1]}` doesn't match: `{output_channels}`.")
if output.dtype != output_dtype:
raise ValueError(f"dtype of output data `{output.dtype}` doesn't match: {output_dtype}.")
logger.info("data shape of network is verified with no error.")


def ckpt_export(
    net_id: Optional[str] = None,
    filepath: Optional[PathLike] = None,
    ckpt_file: Optional[str] = None,
    meta_file: Optional[Union[str, Sequence[str]]] = None,
    config_file: Optional[Union[str, Sequence[str]]] = None,
    key_in_ckpt: Optional[str] = None,
    args_file: Optional[str] = None,
    **override,
):
    """
    Export the model checkpoint to the given filepath with metadata and config included as JSON files.

    Typical usage examples:

    .. code-block:: bash

        python -m monai.bundle ckpt_export network --filepath <export path> --ckpt_file <checkpoint path> ...

    Args:
        net_id: ID name of the network component in the config, it must be `torch.nn.Module`.
        filepath: filepath to export, if filename has no extension it becomes `.ts`.
        ckpt_file: filepath of the model checkpoint to load.
        meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.
        config_file: filepath of the config file, if `None`, must be provided in `args_file`.
            if it is a list of file paths, the content of them will be merged.
        key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
            weights. if not nested checkpoint, no need to set.
        args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
            `net_id` and override pairs. so that the command line inputs can be simplified.
        override: id-value pairs to override or add the corresponding config content.
            e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.

    """
    _args = _update_args(
        args=args_file,
        net_id=net_id,
        filepath=filepath,
        meta_file=meta_file,
        config_file=config_file,
        ckpt_file=ckpt_file,
        key_in_ckpt=key_in_ckpt,
        **override,
    )
    # log under the command's own name, consistent with `run` / `verify_metadata` / `verify_net_in_out`
    _log_input_summary(tag="ckpt_export", args=_args)
    filepath_, ckpt_file_, config_file_, net_id_, meta_file_, key_in_ckpt_ = _pop_args(
        _args, "filepath", "ckpt_file", "config_file", net_id="", meta_file=None, key_in_ckpt=""
    )

    parser = ConfigParser()

    parser.read_config(f=config_file_)
    if meta_file_ is not None:
        parser.read_meta(f=meta_file_)

    # the rest key-values in the _args are to override config content
    for k, v in _args.items():
        parser[k] = v

    net = parser.get_parsed_content(net_id_)
    if has_ignite:
        # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
        Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_)
    else:
        # BUGFIX: load the checkpoint into a state dict first; indexing the file path string
        # (`ckpt_file_[key_in_ckpt_]`) would raise a TypeError for nested checkpoints.
        ckpt = torch.load(ckpt_file_)
        copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])

    # convert to TorchScript model and save with meta data, config content
    net = convert_to_torchscript(model=net)

    save_net_with_metadata(
        jit_obj=net,
        filename_prefix_or_stream=filepath_,
        include_config_vals=False,
        append_timestamp=False,
        meta_values=parser.get().pop("_meta_", None),
        more_extra_files={"config": json.dumps(parser.get()).encode()},
    )
    logger.info(f"exported to TorchScript file: {filepath_}.")
6 changes: 3 additions & 3 deletions monai/data/torchscript_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def save_net_with_metadata(
save_net_with_metadata(m, "test", meta_values=meta)

# load the network back, `loaded_meta` has same data as `meta` plus version information
loaded_net, loaded_meta, _ = load_net_with_metadata("test.pt")
loaded_net, loaded_meta, _ = load_net_with_metadata("test.ts")


Args:
jit_obj: object to save, should be generated by `script` or `trace`.
filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.pt`.
filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.ts`.
include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata.
append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension.
meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`.
Expand Down Expand Up @@ -97,7 +97,7 @@ def save_net_with_metadata(
if isinstance(filename_prefix_or_stream, str):
filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream)
if ext == "":
ext = ".pt"
ext = ".ts"

if append_timestamp:
filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}")
Expand Down
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def run_testsuit():
"test_parallel_execution_dist",
"test_bundle_verify_metadata",
"test_bundle_verify_net",
"test_bundle_ckpt_export",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
54 changes: 54 additions & 0 deletions tests/test_bundle_ckpt_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import tempfile
import unittest

from parameterized import parameterized

from monai.bundle import ConfigParser
from monai.networks import save_state
from tests.utils import skip_if_windows

TEST_CASE_1 = [""]

TEST_CASE_2 = ["model"]


@skip_if_windows
class TestCKPTExport(unittest.TestCase):
    """End-to-end test: build a network from config, save a checkpoint, export it via the CLI."""

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
    def test_export(self, key_in_ckpt):
        # metadata/config fixtures shipped with the test data directory
        testing_dir = os.path.join(os.path.dirname(__file__), "testing_data")
        meta_file = os.path.join(testing_dir, "metadata.json")
        config_file = os.path.join(testing_dir, "inference.json")
        with tempfile.TemporaryDirectory() as tempdir:
            default_args = {"meta_file": "will be replaced by `meta_file` arg"}
            default_args_path = os.path.join(tempdir, "def_args.json")
            ckpt_path = os.path.join(tempdir, "model.pt")
            export_path = os.path.join(tempdir, "model.ts")

            parser = ConfigParser()
            parser.export_config_file(config=default_args, filepath=default_args_path)
            parser.read_config(config_file)

            # instantiate the network from config and persist a (possibly nested) checkpoint
            net = parser.get_parsed_content("network_def")
            save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_path)

            cmd = [
                "coverage",
                "run",
                "-m",
                "monai.bundle",
                "ckpt_export",
                "network_def",
                "--filepath",
                export_path,
                "--meta_file",
                meta_file,
                "--config_file",
                config_file,
                "--ckpt_file",
                ckpt_path,
                "--key_in_ckpt",
                key_in_ckpt,
                "--args_file",
                default_args_path,
            ]
            subprocess.check_call(cmd)
            self.assertTrue(os.path.exists(export_path))


if __name__ == "__main__":
unittest.main()
14 changes: 4 additions & 10 deletions tests/test_bundle_verify_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@
# limitations under the License.

import json
import logging
import os
import subprocess
import sys
import tempfile
import unittest

Expand All @@ -40,29 +38,25 @@ def setUp(self):

@parameterized.expand([TEST_CASE_1])
def test_verify(self, meta_file, schema_file):
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
with tempfile.TemporaryDirectory() as tempdir:
def_args = {"meta_file": "will be replaced by `meta_file` arg"}
def_args_file = os.path.join(tempdir, "def_args.json")
ConfigParser.export_config_file(config=def_args, filepath=def_args_file)

cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file]
cmd = ["coverage", "run", "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file]
cmd += ["--filepath", schema_file, "--hash_val", self.config["hash_val"], "--args_file", def_args_file]
ret = subprocess.check_call(cmd)
self.assertEqual(ret, 0)
subprocess.check_call(cmd)

def test_verify_error(self):
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
with tempfile.TemporaryDirectory() as tempdir:
filepath = os.path.join(tempdir, "schema.json")
metafile = os.path.join(tempdir, "metadata.json")
meta_dict = {"schema": self.config["url"], "wrong_meta": "wrong content"}
with open(metafile, "w") as f:
json.dump(meta_dict, f)

cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", metafile, "--filepath", filepath]
ret = subprocess.check_call(cmd)
self.assertEqual(ret, 0)
cmd = ["coverage", "run", "-m", "monai.bundle", "verify_metadata", metafile, "--filepath", filepath]
subprocess.check_call(cmd)


if __name__ == "__main__":
Expand Down
Loading