diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 87c4bf36d2..297409cd7e 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -35,6 +35,7 @@ Model Bundle `Scripts` --------- +.. autofunction:: ckpt_export .. autofunction:: run .. autofunction:: verify_metadata .. autofunction:: verify_net_in_out diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 72c8805e9f..d6a452b5a4 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -12,5 +12,5 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser from .reference_resolver import ReferenceResolver -from .scripts import run, verify_metadata, verify_net_in_out +from .scripts import ckpt_export, run, verify_metadata, verify_net_in_out from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index 0ff0a476ef..d77b396e79 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -10,7 +10,7 @@ # limitations under the License. -from monai.bundle.scripts import run, verify_metadata, verify_net_in_out +from monai.bundle.scripts import ckpt_export, run, verify_metadata, verify_net_in_out if __name__ == "__main__": from monai.utils import optional_import diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 5bbde5fd62..13a5a19623 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -10,19 +10,24 @@ # limitations under the License. 
import ast +import json import pprint import re from typing import Dict, Optional, Sequence, Tuple, Union import torch +from torch.cuda import is_available from monai.apps.utils import download_url, get_logger from monai.bundle.config_parser import ConfigParser -from monai.config import PathLike -from monai.utils import check_parent_dir, get_equivalent_dtype, optional_import +from monai.config import IgniteInfo, PathLike +from monai.data import save_net_with_metadata +from monai.networks import convert_to_torchscript, copy_model_state +from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") +Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") logger = get_logger(module_name=__name__) @@ -54,6 +59,14 @@ def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = Tr return args_ +def _pop_args(src: Dict, *args, **kwargs): + """ + Pop args from the `src` dictionary based on specified keys in `args` and (key, default value) pairs in `kwargs`. 
+ + """ + return tuple([src.pop(i) for i in args] + [src.pop(k, v) for k, v in kwargs.items()]) + + def _log_input_summary(tag, args: Dict): logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---") for name, val in args.items(): @@ -148,20 +161,22 @@ def run( if "config_file" not in _args: raise ValueError(f"`config_file` is required for 'monai.bundle run'.\n{run.__doc__}") _log_input_summary(tag="run", args=_args) + config_file_, meta_file_, runner_id_ = _pop_args(_args, "config_file", meta_file=None, runner_id="") parser = ConfigParser() - parser.read_config(f=_args.pop("config_file")) - if "meta_file" in _args: - parser.read_meta(f=_args.pop("meta_file")) - id = _args.pop("runner_id", "") + parser.read_config(f=config_file_) + if meta_file_ is not None: + parser.read_meta(f=meta_file_) # the rest key-values in the _args are to override config content for k, v in _args.items(): parser[k] = v - workflow = parser.get_parsed_content(id=id) + workflow = parser.get_parsed_content(id=runner_id_) if not hasattr(workflow, "run"): - raise ValueError(f"The parsed workflow {type(workflow)} (id={id}) does not have a `run` method.\n{run.__doc__}") + raise ValueError( + f"The parsed workflow {type(workflow)} (id={runner_id_}) does not have a `run` method.\n{run.__doc__}" + ) return workflow.run() @@ -203,22 +218,16 @@ def verify_metadata( **kwargs, ) _log_input_summary(tag="verify_metadata", args=_args) + filepath_, meta_file_, create_dir_, hash_val_, hash_type_ = _pop_args( + _args, "filepath", "meta_file", create_dir=True, hash_val=None, hash_type="md5" + ) - filepath_ = _args.pop("filepath") - create_dir_ = _args.pop("create_dir", True) check_parent_dir(path=filepath_, create_dir=create_dir_) - - metadata = ConfigParser.load_config_files(files=_args.pop("meta_file")) + metadata = ConfigParser.load_config_files(files=meta_file_) url = metadata.get("schema") if url is None: raise ValueError("must provide the `schema` field in the metadata for the URL of schema 
file.") - download_url( - url=url, - filepath=filepath_, - hash_val=_args.pop("hash_val", None), - hash_type=_args.pop("hash_type", "md5"), - progress=True, - ) + download_url(url=url, filepath=filepath_, hash_val=hash_val_, hash_type=hash_type_, progress=True) schema = ConfigParser.load_config_file(filepath=filepath_) try: @@ -262,8 +271,8 @@ def verify_net_in_out( p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1. n: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1. any: specified size to generate fake data shape if dim of expected shape is "*", default to 1. - args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, - `net_id` and override pairs. so that the command line inputs can be simplified. + args_file: a JSON or YAML file to provide default values for `net_id`, `meta_file`, `config_file`, + `device`, `p`, `n`, `any`, and override pairs. so that the command line inputs can be simplified. override: id-value pairs to override or add the corresponding config content. e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``. 
@@ -281,22 +290,20 @@ def verify_net_in_out( **override, ) _log_input_summary(tag="verify_net_in_out", args=_args) + config_file_, meta_file_, net_id_, device_, p_, n_, any_ = _pop_args( + _args, "config_file", "meta_file", net_id="", device="cuda:0" if is_available() else "cpu", p=1, n=1, any=1 + ) parser = ConfigParser() - parser.read_config(f=_args.pop("config_file")) - parser.read_meta(f=_args.pop("meta_file")) - id = _args.pop("net_id", "") - device_ = torch.device(_args.pop("device", "cuda:0" if torch.cuda.is_available() else "cpu")) - p = _args.pop("p", 1) - n = _args.pop("n", 1) - any = _args.pop("any", 1) + parser.read_config(f=config_file_) + parser.read_meta(f=meta_file_) # the rest key-values in the _args are to override config content for k, v in _args.items(): parser[k] = v try: - key: str = id # mark the full id when KeyError + key: str = net_id_ # mark the full id when KeyError net = parser.get_parsed_content(key).to(device_) key = "_meta_#network_data_format#inputs#image#num_channels" input_channels = parser[key] @@ -313,7 +320,7 @@ def verify_net_in_out( net.eval() with torch.no_grad(): - spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p, n=n, any=any) # type: ignore + spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p_, n=n_, any=any_) # type: ignore test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device_) output = net(test_data) if output.shape[1] != output_channels: @@ -321,3 +328,83 @@ def verify_net_in_out( if output.dtype != output_dtype: raise ValueError(f"dtype of output data `{output.dtype}` doesn't match: {output_dtype}.") logger.info("data shape of network is verified with no error.") + + +def ckpt_export( + net_id: Optional[str] = None, + filepath: Optional[PathLike] = None, + ckpt_file: Optional[str] = None, + meta_file: Optional[Union[str, Sequence[str]]] = None, + config_file: Optional[Union[str, Sequence[str]]] = None, + key_in_ckpt: Optional[str] = None, + 
args_file: Optional[str] = None, + **override, +): + """ + Export the model checkpoint to the given filepath with metadata and config included as JSON files. + + Typical usage examples: + + .. code-block:: bash + + python -m monai.bundle ckpt_export network --filepath <export path> --ckpt_file <checkpoint path> ... + + Args: + net_id: ID name of the network component in the config, it must be `torch.nn.Module`. + filepath: filepath to export, if filename has no extension it becomes `.ts`. + ckpt_file: filepath of the model checkpoint to load. + meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. + config_file: filepath of the config file, if `None`, must be provided in `args_file`. + if it is a list of file paths, the content of them will be merged. + key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model + weights. if not nested checkpoint, no need to set. + args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, + `net_id` and override pairs. so that the command line inputs can be simplified. + override: id-value pairs to override or add the corresponding config content. + e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``. 
+ + """ + _args = _update_args( + args=args_file, + net_id=net_id, + filepath=filepath, + meta_file=meta_file, + config_file=config_file, + ckpt_file=ckpt_file, + key_in_ckpt=key_in_ckpt, + **override, + ) + _log_input_summary(tag="export", args=_args) + filepath_, ckpt_file_, config_file_, net_id_, meta_file_, key_in_ckpt_ = _pop_args( + _args, "filepath", "ckpt_file", "config_file", net_id="", meta_file=None, key_in_ckpt="" + ) + + parser = ConfigParser() + + parser.read_config(f=config_file_) + if meta_file_ is not None: + parser.read_meta(f=meta_file_) + + # the rest key-values in the _args are to override config content + for k, v in _args.items(): + parser[k] = v + + net = parser.get_parsed_content(net_id_) + if has_ignite: + # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver + Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) + else: + copy_model_state(dst=net, src=ckpt_file_ if key_in_ckpt_ == "" else ckpt_file_[key_in_ckpt_]) + + # convert to TorchScript model and save with meta data, config content + net = convert_to_torchscript(model=net) + + save_net_with_metadata( + jit_obj=net, + filename_prefix_or_stream=filepath_, + include_config_vals=False, + append_timestamp=False, + meta_values=parser.get().pop("_meta_", None), + more_extra_files={"config": json.dumps(parser.get()).encode()}, + ) + logger.info(f"exported to TorchScript file: {filepath_}.") diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 585db14712..61477e8ca9 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -56,12 +56,12 @@ def save_net_with_metadata( save_net_with_metadata(m, "test", meta_values=meta) # load the network back, `loaded_meta` has same data as `meta` plus version information - loaded_net, loaded_meta, _ = load_net_with_metadata("test.pt") + loaded_net, loaded_meta, _ = load_net_with_metadata("test.ts") Args: jit_obj: object to save, 
should be generated by `script` or `trace`. - filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.pt`. + filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.ts`. include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. @@ -97,7 +97,7 @@ def save_net_with_metadata( if isinstance(filename_prefix_or_stream, str): filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) if ext == "": - ext = ".pt" + ext = ".ts" if append_timestamp: filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") diff --git a/tests/min_tests.py b/tests/min_tests.py index c0d4f36430..9bf95f3f49 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -162,6 +162,7 @@ def run_testsuit(): "test_parallel_execution_dist", "test_bundle_verify_metadata", "test_bundle_verify_net", + "test_bundle_ckpt_export", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py new file mode 100644 index 0000000000..0f7d0f7d35 --- /dev/null +++ b/tests/test_bundle_ckpt_export.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import tempfile +import unittest + +from parameterized import parameterized + +from monai.bundle import ConfigParser +from monai.networks import save_state +from tests.utils import skip_if_windows + +TEST_CASE_1 = [""] + +TEST_CASE_2 = ["model"] + + +@skip_if_windows +class TestCKPTExport(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_export(self, key_in_ckpt): + meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json") + config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json") + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"meta_file": "will be replaced by `meta_file` arg"} + def_args_file = os.path.join(tempdir, "def_args.json") + ckpt_file = os.path.join(tempdir, "model.pt") + ts_file = os.path.join(tempdir, "model.ts") + + parser = ConfigParser() + parser.export_config_file(config=def_args, filepath=def_args_file) + parser.read_config(config_file) + net = parser.get_parsed_content("network_def") + save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) + + cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file] + cmd += ["--meta_file", meta_file, "--config_file", config_file, "--ckpt_file", ckpt_file] + cmd += ["--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file] + subprocess.check_call(cmd) + self.assertTrue(os.path.exists(ts_file)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py index b018c9a568..c816c081eb 100644 --- a/tests/test_bundle_verify_metadata.py +++ b/tests/test_bundle_verify_metadata.py @@ -10,10 +10,8 @@ # limitations under the License. 
import json -import logging import os import subprocess -import sys import tempfile import unittest @@ -40,19 +38,16 @@ def setUp(self): @parameterized.expand([TEST_CASE_1]) def test_verify(self, meta_file, schema_file): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) with tempfile.TemporaryDirectory() as tempdir: def_args = {"meta_file": "will be replaced by `meta_file` arg"} def_args_file = os.path.join(tempdir, "def_args.json") ConfigParser.export_config_file(config=def_args, filepath=def_args_file) - cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file] + cmd = ["coverage", "run", "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file] cmd += ["--filepath", schema_file, "--hash_val", self.config["hash_val"], "--args_file", def_args_file] - ret = subprocess.check_call(cmd) - self.assertEqual(ret, 0) + subprocess.check_call(cmd) def test_verify_error(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) with tempfile.TemporaryDirectory() as tempdir: filepath = os.path.join(tempdir, "schema.json") metafile = os.path.join(tempdir, "metadata.json") @@ -60,9 +55,8 @@ def test_verify_error(self): with open(metafile, "w") as f: json.dump(meta_dict, f) - cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", metafile, "--filepath", filepath] - ret = subprocess.check_call(cmd) - self.assertEqual(ret, 0) + cmd = ["coverage", "run", "-m", "monai.bundle", "verify_metadata", metafile, "--filepath", filepath] + subprocess.check_call(cmd) if __name__ == "__main__": diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py index 62f99aab99..33d480d83f 100644 --- a/tests/test_bundle_verify_net.py +++ b/tests/test_bundle_verify_net.py @@ -11,7 +11,6 @@ import os import subprocess -import sys import tempfile import unittest @@ -35,14 +34,13 @@ def test_verify(self, meta_file, config_file): def_args_file = os.path.join(tempdir, "def_args.json") 
ConfigParser.export_config_file(config=def_args, filepath=def_args_file) - cmd = [sys.executable, "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file", meta_file] - cmd += ["--config_file", config_file, "-n", "2", "--any", "32", "--args_file", def_args_file] + cmd = ["coverage", "run", "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file"] + cmd += [meta_file, "--config_file", config_file, "-n", "2", "--any", "32", "--args_file", def_args_file] cmd += ["--_meta_#network_data_format#inputs#image#spatial_shape", "[32,'*','4**p*n']"] test_env = os.environ.copy() print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) - ret = subprocess.check_call(cmd, env=test_env) - self.assertEqual(ret, 0) + subprocess.check_call(cmd, env=test_env) if __name__ == "__main__": diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index e6d4dfd89f..af97d9e9ad 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -49,9 +49,8 @@ def test_tiny(self): config_file = os.path.join(self.data_dir, "tiny_config.json") with open(config_file, "w") as f: json.dump({"": {"_target_": "tests.test_integration_bundle_run._Runnable42", "val": 42}}, f) - cmd = [sys.executable, "-m", "monai.bundle", "run", "--config_file", config_file] - ret = subprocess.check_call(cmd) - self.assertEqual(ret, 0) + cmd = ["coverage", "run", "-m", "monai.bundle", "run", "--config_file", config_file] + subprocess.check_call(cmd) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, config_file, expected_shape): @@ -89,19 +88,17 @@ def test_shape(self, config_file, expected_shape): override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}" # test with `monai.bundle` as CLI entry directly cmd = f"-m monai.bundle run evaluator --postprocessing#transforms#2#output_postfix seg {override}" - la = [f"{sys.executable}"] + cmd.split(" ") + 
["--meta_file", meta_file] + ["--config_file", config_file] + la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] test_env = os.environ.copy() print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) - ret = subprocess.check_call(la + ["--args_file", def_args_file], env=test_env) - self.assertEqual(ret, 0) + subprocess.check_call(la + ["--args_file", def_args_file], env=test_env) self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape) # here test the script with `google fire` tool as CLI cmd = "-m fire monai.bundle.scripts run --runner_id evaluator" cmd += f" --evaluator#amp False {override}" - la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] - ret = subprocess.check_call(la, env=test_env) - self.assertEqual(ret, 0) + la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] + subprocess.check_call(la, env=test_env) self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape) diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py index b26d41345a..d6bea09ed6 100644 --- a/tests/test_torchscript_utils.py +++ b/tests/test_torchscript_utils.py @@ -34,7 +34,7 @@ def test_save_net_with_metadata(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test") - self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + self.assertTrue(os.path.isfile(f"{tempdir}/test.ts")) def test_save_net_with_metadata_ext(self): """Save a network without metadata to a file.""" @@ -54,7 +54,7 @@ def test_save_net_with_metadata_with_extra(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) - self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + 
self.assertTrue(os.path.isfile(f"{tempdir}/test.ts")) def test_load_net_with_metadata(self): """Save then load a network with no metadata or other extra files.""" @@ -62,7 +62,7 @@ def test_load_net_with_metadata(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test") - _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.ts") del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be @@ -77,7 +77,7 @@ def test_load_net_with_metadata_with_extra(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) - _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.ts") del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be @@ -98,9 +98,9 @@ def test_save_load_more_extra_files(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata, more_extra_files=more_extra_files) - self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + self.assertTrue(os.path.isfile(f"{tempdir}/test.ts")) - _, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.pt", more_extra_files=("test.txt",)) + _, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.ts", more_extra_files=("test.txt",)) if pytorch_after(1, 7): self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"])