From 59a75243d4522384ccc23c4106e054f57efbd7d7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 17 Mar 2022 12:39:34 +0800 Subject: [PATCH 01/18] [DLMED] add export script Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 67 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 1f3165dee3..e655616243 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -9,17 +9,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import pprint import re from typing import Dict, Optional, Sequence, Union +import torch from monai.apps.utils import download_url, get_logger from monai.bundle.config_parser import ConfigParser -from monai.config import PathLike -from monai.utils import check_parent_dir, optional_import +from monai.config import PathLike, IgniteInfo + +from monai.data import save_net_with_metadata +from monai.networks import convert_to_torchscript, copy_model_state +from monai.utils import check_parent_dir, min_version, optional_import validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") +Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") logger = get_logger(module_name=__name__) @@ -172,3 +178,60 @@ def verify_metadata( logger.info(re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`.") return logger.info("metadata is verified with no error.") + + +def export( + net_id: Optional[str] = None, + filepath: Optional[PathLike] = None, + meta_file: Optional[Union[str, Sequence[str]]] = None, + config_file: Optional[Union[str, Sequence[str]]] = None, + ckpt_file: Optional[str] = None, + key_in_ckpt: Optional[str] = None, + args_file: Optional[str] = None, + **override, +): + _args = _update_args( + args=args_file, + net_id=net_id, + filepath=filepath, + meta_file=meta_file, + config_file=config_file, + ckpt_file=ckpt_file, + key_in_ckpt=key_in_ckpt, + **override, + ) + _log_input_summary(tag="export", args=_args) + + parser = ConfigParser() + config_file_ = _args.pop("config_file") + parser.read_config(f=config_file_) + meta_file = _args.pop("meta_file") + if meta_file is not None: + parser.read_meta(f=meta_file) + id = _args.pop("net_id", "") + path = _args.pop("filepath") + ckpt = torch.load(_args.pop("ckpt_file")) + key = _args.pop("key_in_ckpt", "") + + # the rest key-values in the _args are to override config content + for k, v in _args.items(): + parser[k] = v + + net = parser.get_parsed_content(id) + if has_ignite: + # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver + Checkpoint.load_objects(to_load={key: net}, checkpoint=ckpt) + else: + copy_model_state(dst=net, src=ckpt if key == "" else ckpt[key]) + + # convert to TorchScript model and save with meta data, config content + net = convert_to_torchscript(model=net) + + save_net_with_metadata( + jit_obj=net, + filename_prefix_or_stream=path, + include_config_vals=False, + append_timestamp=False, + meta_values=parser.get().pop("_meta_", None), + more_extra_files={"config": json.dumps(parser.get()).encode()}, + ) From 8d542837d596238c8322ad1c9d41896a70335b0f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 18 Mar 2022 00:48:17 +0800 Subject: [PATCH 02/18] [DLMED] add base unit test Signed-off-by: Nic Ma --- docs/source/bundle.rst | 1 + monai/bundle/__init__.py | 2 +- monai/bundle/__main__.py | 2 +- tests/min_tests.py | 1 + tests/test_bundle_export.py | 55 +++++++++++++++++++++++++++++++++++++ 5 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 tests/test_bundle_export.py diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 87c4bf36d2..7b6a95dcbb 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -35,6 +35,7 @@ Model Bundle `Scripts` --------- +.. autofunction:: export .. autofunction:: run .. autofunction:: verify_metadata .. autofunction:: verify_net_in_out diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 72c8805e9f..6cccbf4002 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -12,5 +12,5 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser from .reference_resolver import ReferenceResolver -from .scripts import run, verify_metadata, verify_net_in_out +from .scripts import export, run, verify_metadata, verify_net_in_out from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index 0ff0a476ef..6a754bc985 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -10,7 +10,7 @@ # limitations under the License. -from monai.bundle.scripts import run, verify_metadata, verify_net_in_out +from monai.bundle.scripts import export, run, verify_metadata, verify_net_in_out if __name__ == "__main__": from monai.utils import optional_import diff --git a/tests/min_tests.py b/tests/min_tests.py index c0d4f36430..3d93d96dda 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -162,6 +162,7 @@ def run_testsuit(): "test_parallel_execution_dist", "test_bundle_verify_metadata", "test_bundle_verify_net", + "test_bundle_export", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_bundle_export.py b/tests/test_bundle_export.py new file mode 100644 index 0000000000..99c4a4e225 --- /dev/null +++ b/tests/test_bundle_export.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import sys +import tempfile +import unittest + +from parameterized import parameterized + +from monai.bundle import ConfigParser +from monai.networks import save_state +from tests.utils import skip_if_windows + +TEST_CASE_1 = [ + os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json"), + os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"), +] + + +@skip_if_windows +class TestExport(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_export(self, meta_file, config_file): + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"meta_file": "will be replaced by `meta_file` arg"} + def_args_file = os.path.join(tempdir, "def_args.json") + ckpt_file = os.path.join(tempdir, "model.pt") + ts_file = os.path.join(tempdir, "model.ts") + + parser = ConfigParser() + parser.export_config_file(config=def_args, filepath=def_args_file) + parser.read_config(config_file) + net = parser.get_parsed_content("network_def") + save_state(src=net, path=ckpt_file) + + cmd = [sys.executable, "-m", "monai.bundle", "export", "network_def", "--filepath", ts_file] + cmd += ["--meta_file", meta_file, "--config_file", config_file, "--ckpt_file", ckpt_file] + cmd += ["--args_file", def_args_file] + ret = subprocess.check_call(cmd) + self.assertEqual(ret, 0) + self.assertTrue(os.path.exists(ts_file)) + + +if __name__ == "__main__": + unittest.main() From b750b73defaa2c732bdf7de22b2e18b02c6ff20f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 18 Mar 2022 11:53:24 +0800 Subject: [PATCH 03/18] [DLMED] add custom config item types and reference resolver Signed-off-by: Nic Ma --- monai/bundle/config_parser.py | 34 ++++++++++++++++++++++++++-------- tests/test_bundle_export.py | 17 +++++++++-------- tests/test_config_parser.py | 18 ++++++++++++++---- 3 files changed, 49 insertions(+), 20 deletions(-) diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index 6fa7b3a2a2..a5d7da6996 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -14,7 +14,7 @@ import re from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.reference_resolver import ReferenceResolver @@ -93,6 +93,8 @@ def __init__( config: Any = None, excludes: Optional[Union[Sequence[str], str]] = None, globals: Optional[Dict[str, Any]] = None, + item_types: Optional[Union[Sequence[Type[ConfigItem]], Type[ConfigItem]]] = None, + resolver: Optional[ReferenceResolver] = None, ): self.config = None self.globals: Dict[str, Any] = {} @@ -100,9 +102,17 @@ def __init__( if globals is not None: for k, v in globals.items(): self.globals[k] = importlib.import_module(v) if isinstance(v, str) else v + self.item_types = ( + (ConfigComponent, ConfigExpression, ConfigItem) if item_types is None else ensure_tuple(item_types) + ) self.locator = ComponentLocator(excludes=excludes) - self.ref_resolver = ReferenceResolver() + if resolver is not None: + if not isinstance(resolver, ReferenceResolver): + raise TypeError(f"resolver must be subclass of ReferenceResolver, but got: {type(resolver)}.") + self.ref_resolver = resolver + else: + self.ref_resolver = ReferenceResolver() if config is None: config = {self.meta_key: {}} self.set(config=config) @@ -292,12 +302,20 @@ def _do_parse(self, config, id: str = ""): # copy every config item to make them independent and add them to the resolver item_conf = deepcopy(config) - if ConfigComponent.is_instantiable(item_conf): - self.ref_resolver.add_item(ConfigComponent(config=item_conf, id=id, locator=self.locator)) - elif ConfigExpression.is_expression(item_conf): - self.ref_resolver.add_item(ConfigExpression(config=item_conf, id=id, globals=self.globals)) - else: - self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id)) + for item_type in self.item_types: + if issubclass(item_type, ConfigComponent): + if item_type.is_instantiable(item_conf): + return self.ref_resolver.add_item(item_type(config=item_conf, id=id, locator=self.locator)) + continue + if issubclass(item_type, ConfigExpression): + if item_type.is_expression(item_conf): + return self.ref_resolver.add_item(item_type(config=item_conf, id=id, globals=self.globals)) + continue + if issubclass(item_type, ConfigItem): + return self.ref_resolver.add_item(item_type(config=item_conf, id=id)) + raise TypeError( + f"item type must be subclass of `ConfigComponent`, `ConfigExpression`, `ConfigItem`, got: {item_type}." + ) @classmethod def load_config_file(cls, filepath: PathLike, **kwargs): diff --git a/tests/test_bundle_export.py b/tests/test_bundle_export.py index 99c4a4e225..d8b4875d5e 100644 --- a/tests/test_bundle_export.py +++ b/tests/test_bundle_export.py @@ -21,16 +21,17 @@ from monai.networks import save_state from tests.utils import skip_if_windows -TEST_CASE_1 = [ - os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json"), - os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"), -] +TEST_CASE_1 = [""] + +TEST_CASE_2 = ["model"] @skip_if_windows class TestExport(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) - def test_export(self, meta_file, config_file): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_export(self, key_in_ckpt): + meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json") + config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json") with tempfile.TemporaryDirectory() as tempdir: def_args = {"meta_file": "will be replaced by `meta_file` arg"} def_args_file = os.path.join(tempdir, "def_args.json") @@ -41,11 +42,11 @@ def test_export(self, meta_file, config_file): parser.export_config_file(config=def_args, filepath=def_args_file) parser.read_config(config_file) net = parser.get_parsed_content("network_def") - save_state(src=net, path=ckpt_file) + save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) cmd = [sys.executable, "-m", "monai.bundle", "export", "network_def", "--filepath", ts_file] cmd += ["--meta_file", meta_file, "--config_file", config_file, "--ckpt_file", ckpt_file] - cmd += ["--args_file", def_args_file] + cmd += ["--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file] ret = subprocess.check_call(cmd) self.assertEqual(ret, 0) self.assertTrue(os.path.exists(ts_file)) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index ce98be1214..421365ebc0 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -11,10 +11,11 @@ import unittest from unittest import skipUnless +from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem from parameterized import parameterized -from monai.bundle.config_parser import ConfigParser +from monai.bundle.config_parser import ConfigParser, ReferenceResolver from monai.data import DataLoader, Dataset from monai.transforms import Compose, LoadImaged, RandTorchVisiond from monai.utils import min_version, optional_import @@ -57,6 +58,10 @@ def __call__(self, a, b): return self.compute(a, b) +class TestConfigComponent(ConfigComponent): + pass + + TEST_CASE_2 = [ { "basic_func": "$lambda x, y: x + y", @@ -73,7 +78,7 @@ def __call__(self, a, b): ] -class TestConfigComponent(unittest.TestCase): +class TestConfigParser(unittest.TestCase): def test_config_content(self): test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}} parser = ConfigParser(config=test_config) @@ -94,7 +99,7 @@ def test_config_content(self): @parameterized.expand([TEST_CASE_1]) @skipUnless(has_tv, "Requires torchvision >= 0.8.0.") def test_parse(self, config, expected_ids, output_types): - parser = ConfigParser(config=config, globals={"monai": "monai"}) + parser = ConfigParser(config=config, globals={"monai": "monai"}, resolver=ReferenceResolver()) # test lazy instantiation with original config content parser["transform"]["transforms"][0]["keys"] = "label1" self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label1") @@ -110,7 +115,11 @@ def test_parse(self, config, expected_ids, output_types): @parameterized.expand([TEST_CASE_2]) def test_function(self, config): - parser = ConfigParser(config=config, globals={"TestClass": TestClass}) + parser = ConfigParser( + config=config, + globals={"TestClass": TestClass}, + item_types=(TestConfigComponent, ConfigExpression, ConfigItem), + ) for id in config: func = parser.get_parsed_content(id=id) self.assertTrue(id in parser.ref_resolver.resolved_content) @@ -122,4 +131,5 @@ def test_function(self, config): if __name__ == "__main__": + unittest.main() From 36d8469912eab15b47c5b7ec42e6e8105ec104e4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Mar 2022 03:54:22 +0000 Subject: [PATCH 04/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_config_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 421365ebc0..28ef27128b 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -131,5 +131,5 @@ def test_function(self, config): if __name__ == "__main__": - + unittest.main() From 1fbd3853400c2833212b1dc7850cc66ffea9445c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 18 Mar 2022 12:21:54 +0800 Subject: [PATCH 05/18] [DLMED] enhance doc-string Signed-off-by: Nic Ma --- monai/bundle/config_parser.py | 5 +++++ monai/bundle/scripts.py | 33 ++++++++++++++++++++++++++++----- tests/test_config_parser.py | 4 ++-- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index a5d7da6996..71ee176288 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -75,6 +75,11 @@ class ConfigParser: The current supported globals and alias names are ``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``. These are MONAI's minimal dependencies. + item_types: list of supported config item types, must be subclass of `ConfigComponent`, + `ConfigExpression`, `ConfigItem`, will check the type in order for every config item. + if `None`, default to: ``(ConfigComponent, ConfigExpression, ConfigItem)``. + resolver: manage a set of ``ConfigItem`` and resolve the references between them. + if `None`, will create a default `ReferenceResolver` instance. See also: diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index dccf3c8049..46ba888a7b 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -17,10 +17,9 @@ import torch -import torch from monai.apps.utils import download_url, get_logger from monai.bundle.config_parser import ConfigParser -from monai.config import PathLike, IgniteInfo +from monai.config import IgniteInfo, PathLike from monai.data import save_net_with_metadata from monai.networks import convert_to_torchscript, copy_model_state from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import @@ -267,8 +266,8 @@ def verify_net_in_out( p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1. p: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1. any: specified size to generate fake data shape if dim of expected shape is "*", default to 1. - args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, - `net_id` and override pairs. so that the command line inputs can be simplified. + args_file: a JSON or YAML file to provide default values for `net_id`, `meta_file`, `config_file`, + `device`, `p`, `n`, `any`, and override pairs. so that the command line inputs can be simplified. override: id-value pairs to override or add the corresponding config content. e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``. @@ -331,13 +330,37 @@ def verify_net_in_out( def export( net_id: Optional[str] = None, filepath: Optional[PathLike] = None, + ckpt_file: Optional[str] = None, meta_file: Optional[Union[str, Sequence[str]]] = None, config_file: Optional[Union[str, Sequence[str]]] = None, - ckpt_file: Optional[str] = None, key_in_ckpt: Optional[str] = None, args_file: Optional[str] = None, **override, ): + """ + Export the model checkpoint the to the given filepath with metadata and config included as JSON files. + + Typical usage examples: + + .. code-block:: bash + + python -m monai.bundle export network --filepath --ckpt_file ... + + Args: + net_id: ID name of the network component in the config, it must be `torch.nn.Module`. + filepath: filepath to export, if filename has no extension it becomes `.pt`. + ckpt_file: filepath of the model checkpoint to load. + meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. + config_file: filepath of the config file, if `None`, must be provided in `args_file`. + if it is a list of file paths, the content of them will be merged. + key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model + weights. if not nested checkpoint, no need to set. + args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, + `net_id` and override pairs. so that the command line inputs can be simplified. + override: id-value pairs to override or add the corresponding config content. + e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``. + + """ _args = _update_args( args=args_file, net_id=net_id, diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 421365ebc0..da18528e33 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -11,10 +11,10 @@ import unittest from unittest import skipUnless -from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem from parameterized import parameterized +from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.config_parser import ConfigParser, ReferenceResolver from monai.data import DataLoader, Dataset from monai.transforms import Compose, LoadImaged, RandTorchVisiond @@ -131,5 +131,5 @@ def test_function(self, config): if __name__ == "__main__": - + unittest.main() From 5c233530309554c4da6669e3c7518d03c6fae41b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 18 Mar 2022 12:26:46 +0800 Subject: [PATCH 06/18] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/bundle/config_parser.py | 2 +- tests/test_config_parser.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index 71ee176288..4f3f0c9065 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -76,7 +76,7 @@ class ConfigParser: ``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``. These are MONAI's minimal dependencies. item_types: list of supported config item types, must be subclass of `ConfigComponent`, - `ConfigExpression`, `ConfigItem`, will check the type in order for every config item. + `ConfigExpression`, `ConfigItem`, will check the types in order for every config item. if `None`, default to: ``(ConfigComponent, ConfigExpression, ConfigItem)``. resolver: manage a set of ``ConfigItem`` and resolve the references between them. if `None`, will create a default `ReferenceResolver` instance. diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index da18528e33..e85f13b6c6 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -131,5 +131,4 @@ def test_function(self, config): if __name__ == "__main__": - unittest.main() From b5b9f20e4a3f1f9b2c59b835ce6b5d5431815cd4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 18 Mar 2022 18:51:49 +0800 Subject: [PATCH 07/18] [DLMED] add logging Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 46ba888a7b..0660d8e8ba 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -406,3 +406,5 @@ def export( meta_values=parser.get().pop("_meta_", None), more_extra_files={"config": json.dumps(parser.get()).encode()}, ) + logger.info(f"exported to TorchScript file: {path}.") + From a90d6333aafb48d9b5155fa14d837d503952f8c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Mar 2022 10:53:48 +0000 Subject: [PATCH 08/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 0660d8e8ba..3d4ef51e79 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -407,4 +407,3 @@ def export( more_extra_files={"config": json.dumps(parser.get()).encode()}, ) logger.info(f"exported to TorchScript file: {path}.") - From efd0b100f2ee47721e50fcc981ed6814f8746ef3 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 22 Mar 2022 18:55:20 +0800 Subject: [PATCH 09/18] [DLMED] remove customized configitem and resolver Signed-off-by: Nic Ma --- monai/bundle/config_parser.py | 39 +++++++---------------------------- tests/test_config_parser.py | 17 ++++----------- 2 files changed, 12 insertions(+), 44 deletions(-) diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index ff561d6d15..23d4ac7c55 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -13,7 +13,7 @@ import re from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.reference_resolver import ReferenceResolver @@ -76,11 +76,6 @@ class ConfigParser: The current supported globals and alias names are ``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``. These are MONAI's minimal dependencies. Additional packages could be included with `globals={"itk": "itk"}`. - item_types: list of supported config item types, must be subclass of `ConfigComponent`, - `ConfigExpression`, `ConfigItem`, will check the types in order for every config item. - if `None`, default to: ``(ConfigComponent, ConfigExpression, ConfigItem)``. - resolver: manage a set of ``ConfigItem`` and resolve the references between them. - if `None`, will create a default `ReferenceResolver` instance. See also: @@ -99,8 +94,6 @@ def __init__( config: Any = None, excludes: Optional[Union[Sequence[str], str]] = None, globals: Optional[Dict[str, Any]] = None, - item_types: Optional[Union[Sequence[Type[ConfigItem]], Type[ConfigItem]]] = None, - resolver: Optional[ReferenceResolver] = None, ): self.config = None self.globals: Dict[str, Any] = {} @@ -110,17 +103,9 @@ def __init__( if _globals is not None: for k, v in _globals.items(): self.globals[k] = optional_import(v)[0] if isinstance(v, str) else v - self.item_types = ( - (ConfigComponent, ConfigExpression, ConfigItem) if item_types is None else ensure_tuple(item_types) - ) self.locator = ComponentLocator(excludes=excludes) - if resolver is not None: - if not isinstance(resolver, ReferenceResolver): - raise TypeError(f"resolver must be subclass of ReferenceResolver, but got: {type(resolver)}.") - self.ref_resolver = resolver - else: - self.ref_resolver = ReferenceResolver() + self.ref_resolver = ReferenceResolver() if config is None: config = {self.meta_key: {}} self.set(config=config) @@ -310,20 +295,12 @@ def _do_parse(self, config, id: str = ""): # copy every config item to make them independent and add them to the resolver item_conf = deepcopy(config) - for item_type in self.item_types: - if issubclass(item_type, ConfigComponent): - if item_type.is_instantiable(item_conf): - return self.ref_resolver.add_item(item_type(config=item_conf, id=id, locator=self.locator)) - continue - if issubclass(item_type, ConfigExpression): - if item_type.is_expression(item_conf): - return self.ref_resolver.add_item(item_type(config=item_conf, id=id, globals=self.globals)) - continue - if issubclass(item_type, ConfigItem): - return self.ref_resolver.add_item(item_type(config=item_conf, id=id)) - raise TypeError( - f"item type must be subclass of `ConfigComponent`, `ConfigExpression`, `ConfigItem`, got: {item_type}." - ) + if ConfigComponent.is_instantiable(item_conf): + self.ref_resolver.add_item(ConfigComponent(config=item_conf, id=id, locator=self.locator)) + elif ConfigExpression.is_expression(item_conf): + self.ref_resolver.add_item(ConfigExpression(config=item_conf, id=id, globals=self.globals)) + else: + self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id)) @classmethod def load_config_file(cls, filepath: PathLike, **kwargs): diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index e85f13b6c6..ce98be1214 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -14,8 +14,7 @@ from parameterized import parameterized -from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem -from monai.bundle.config_parser import ConfigParser, ReferenceResolver +from monai.bundle.config_parser import ConfigParser from monai.data import DataLoader, Dataset from monai.transforms import Compose, LoadImaged, RandTorchVisiond from monai.utils import min_version, optional_import @@ -58,10 +57,6 @@ def __call__(self, a, b): return self.compute(a, b) -class TestConfigComponent(ConfigComponent): - pass - - TEST_CASE_2 = [ { "basic_func": "$lambda x, y: x + y", @@ -78,7 +73,7 @@ class TestConfigComponent(ConfigComponent): ] -class TestConfigParser(unittest.TestCase): +class TestConfigComponent(unittest.TestCase): def test_config_content(self): test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}} parser = ConfigParser(config=test_config) @@ -99,7 +94,7 @@ def test_config_content(self): @parameterized.expand([TEST_CASE_1]) @skipUnless(has_tv, "Requires torchvision >= 0.8.0.") def test_parse(self, config, expected_ids, output_types): - parser = ConfigParser(config=config, globals={"monai": "monai"}, resolver=ReferenceResolver()) + parser = ConfigParser(config=config, globals={"monai": "monai"}) # test lazy instantiation with original config content parser["transform"]["transforms"][0]["keys"] = "label1" self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label1") @@ -115,11 +110,7 @@ def test_parse(self, config, expected_ids, output_types): @parameterized.expand([TEST_CASE_2]) def test_function(self, config): - parser = ConfigParser( - config=config, - globals={"TestClass": TestClass}, - item_types=(TestConfigComponent, ConfigExpression, ConfigItem), - ) + parser = ConfigParser(config=config, globals={"TestClass": TestClass}) for id in config: func = parser.get_parsed_content(id=id) self.assertTrue(id in parser.ref_resolver.resolved_content) From 5498c53b975dc193a9d9b41010ccd8ef558824a7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 22 Mar 2022 22:52:41 +0800 Subject: [PATCH 10/18] [DLMED] remove assert ret Signed-off-by: Nic Ma --- tests/test_bundle_export.py | 3 +-- tests/test_bundle_verify_metadata.py | 6 ++---- tests/test_bundle_verify_net.py | 3 +-- tests/test_integration_bundle_run.py | 9 +++------ 4 files changed, 7 insertions(+), 14 deletions(-) diff --git a/tests/test_bundle_export.py b/tests/test_bundle_export.py index d8b4875d5e..e6aec05fc1 100644 --- a/tests/test_bundle_export.py +++ b/tests/test_bundle_export.py @@ -47,8 +47,7 @@ def test_export(self, key_in_ckpt): cmd = [sys.executable, "-m", "monai.bundle", "export", "network_def", "--filepath", ts_file] cmd += ["--meta_file", meta_file, "--config_file", config_file, "--ckpt_file", ckpt_file] cmd += ["--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file] - ret = subprocess.check_call(cmd) - self.assertEqual(ret, 0) + subprocess.check_call(cmd) self.assertTrue(os.path.exists(ts_file)) diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py index b018c9a568..9449d8f73b 100644 --- a/tests/test_bundle_verify_metadata.py +++ b/tests/test_bundle_verify_metadata.py @@ -48,8 +48,7 @@ def test_verify(self, meta_file, schema_file): cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file] cmd += ["--filepath", schema_file, "--hash_val", self.config["hash_val"], "--args_file", def_args_file] - ret = subprocess.check_call(cmd) - self.assertEqual(ret, 0) + subprocess.check_call(cmd) def test_verify_error(self): logging.basicConfig(stream=sys.stdout, level=logging.INFO) @@ -61,8 +60,7 @@ def test_verify_error(self): json.dump(meta_dict, f) cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", metafile, "--filepath", filepath] - ret = subprocess.check_call(cmd) - self.assertEqual(ret, 0) + subprocess.check_call(cmd) if __name__ == "__main__": diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py index 62f99aab99..5012cc85f4 100644 --- a/tests/test_bundle_verify_net.py +++ b/tests/test_bundle_verify_net.py @@ -41,8 +41,7 @@ def test_verify(self, meta_file, config_file): test_env = os.environ.copy() print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) - ret = subprocess.check_call(cmd, env=test_env) - self.assertEqual(ret, 0) + subprocess.check_call(cmd, env=test_env) if __name__ == "__main__": diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index e6d4dfd89f..90f5aa6408 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -50,8 +50,7 @@ def test_tiny(self): with open(config_file, "w") as f: json.dump({"": {"_target_": "tests.test_integration_bundle_run._Runnable42", "val": 42}}, f) cmd = [sys.executable, "-m", "monai.bundle", "run", "--config_file", config_file] - ret = subprocess.check_call(cmd) - self.assertEqual(ret, 0) + subprocess.check_call(cmd) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, config_file, expected_shape): @@ -92,16 +91,14 @@ def test_shape(self, config_file, expected_shape): la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] test_env = os.environ.copy() print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) - ret = subprocess.check_call(la + ["--args_file", def_args_file], env=test_env) - self.assertEqual(ret, 0) + subprocess.check_call(la + ["--args_file", def_args_file], env=test_env) self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape) # here test the script with `google fire` tool as CLI cmd = "-m fire monai.bundle.scripts run --runner_id evaluator" cmd += f" --evaluator#amp False {override}" la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] - ret = subprocess.check_call(la, env=test_env) - self.assertEqual(ret, 0) + subprocess.check_call(la, env=test_env) self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape) From c87e4584b08eae7746da34691e1275e078cd6778 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 22 Mar 2022 23:02:58 +0800 Subject: [PATCH 11/18] [DLMED] update to ckpt_export Signed-off-by: Nic Ma --- docs/source/bundle.rst | 2 +- monai/bundle/__init__.py | 2 +- monai/bundle/__main__.py | 2 +- monai/bundle/scripts.py | 4 ++-- tests/{test_bundle_export.py => test_bundle_ckpt_export.py} | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) rename tests/{test_bundle_export.py => test_bundle_ckpt_export.py} (93%) diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 7b6a95dcbb..297409cd7e 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -35,7 +35,7 @@ Model Bundle `Scripts` --------- -.. autofunction:: export +.. autofunction:: ckpt_export .. autofunction:: run .. autofunction:: verify_metadata .. autofunction:: verify_net_in_out diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 6cccbf4002..d6a452b5a4 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -12,5 +12,5 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser from .reference_resolver import ReferenceResolver -from .scripts import export, run, verify_metadata, verify_net_in_out +from .scripts import ckpt_export, run, verify_metadata, verify_net_in_out from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index 6a754bc985..d77b396e79 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -10,7 +10,7 @@ # limitations under the License. -from monai.bundle.scripts import export, run, verify_metadata, verify_net_in_out +from monai.bundle.scripts import ckpt_export, run, verify_metadata, verify_net_in_out if __name__ == "__main__": from monai.utils import optional_import diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 3a2415a96d..f864db9f81 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -327,7 +327,7 @@ def verify_net_in_out( logger.info("data shape of network is verified with no error.") -def export( +def ckpt_export( net_id: Optional[str] = None, filepath: Optional[PathLike] = None, ckpt_file: Optional[str] = None, @@ -338,7 +338,7 @@ def export( **override, ): """ - Export the model checkpoint the to the given filepath with metadata and config included as JSON files. + Export the model checkpoint to the given filepath with metadata and config included as JSON files. Typical usage examples: diff --git a/tests/test_bundle_export.py b/tests/test_bundle_ckpt_export.py similarity index 93% rename from tests/test_bundle_export.py rename to tests/test_bundle_ckpt_export.py index e6aec05fc1..9dfec681f9 100644 --- a/tests/test_bundle_export.py +++ b/tests/test_bundle_ckpt_export.py @@ -27,7 +27,7 @@ @skip_if_windows -class TestExport(unittest.TestCase): +class TestCKPTExport(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_export(self, key_in_ckpt): meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json") @@ -44,7 +44,7 @@ def test_export(self, key_in_ckpt): net = parser.get_parsed_content("network_def") save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) - cmd = [sys.executable, "-m", "monai.bundle", "export", "network_def", "--filepath", ts_file] + cmd = [sys.executable, "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file] cmd += ["--meta_file", meta_file, "--config_file", config_file, "--ckpt_file", ckpt_file] cmd += ["--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file] subprocess.check_call(cmd) From 2344abb8f2f0e9bb25506cbbe92ee9c1a50f40b3 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 22 Mar 2022 23:12:47 +0800 Subject: [PATCH 12/18] [DLMED] update to .ts Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 2 +- monai/data/torchscript_utils.py | 6 +++--- tests/test_torchscript_utils.py | 12 ++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index f864db9f81..9a346c5f44 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -348,7 +348,7 @@ def ckpt_export( Args: net_id: ID name of the network component in the config, it must be `torch.nn.Module`. - filepath: filepath to export, if filename has no extension it becomes `.pt`. + filepath: filepath to export, if filename has no extension it becomes `.ts`. ckpt_file: filepath of the model checkpoint to load. meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. config_file: filepath of the config file, if `None`, must be provided in `args_file`. diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 585db14712..61477e8ca9 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -56,12 +56,12 @@ def save_net_with_metadata( save_net_with_metadata(m, "test", meta_values=meta) # load the network back, `loaded_meta` has same data as `meta` plus version information - loaded_net, loaded_meta, _ = load_net_with_metadata("test.pt") + loaded_net, loaded_meta, _ = load_net_with_metadata("test.ts") Args: jit_obj: object to save, should be generated by `script` or `trace`. - filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.pt`. + filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.ts`. include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. @@ -97,7 +97,7 @@ def save_net_with_metadata( if isinstance(filename_prefix_or_stream, str): filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) if ext == "": - ext = ".pt" + ext = ".ts" if append_timestamp: filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py index b26d41345a..d6bea09ed6 100644 --- a/tests/test_torchscript_utils.py +++ b/tests/test_torchscript_utils.py @@ -34,7 +34,7 @@ def test_save_net_with_metadata(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test") - self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + self.assertTrue(os.path.isfile(f"{tempdir}/test.ts")) def test_save_net_with_metadata_ext(self): """Save a network without metadata to a file.""" @@ -54,7 +54,7 @@ def test_save_net_with_metadata_with_extra(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) - self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + self.assertTrue(os.path.isfile(f"{tempdir}/test.ts")) def test_load_net_with_metadata(self): """Save then load a network with no metadata or other extra files.""" @@ -62,7 +62,7 @@ def test_load_net_with_metadata(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test") - _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.ts") del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be @@ -77,7 +77,7 @@ def test_load_net_with_metadata_with_extra(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) - _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.ts") del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be @@ -98,9 +98,9 @@ def test_save_load_more_extra_files(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata, more_extra_files=more_extra_files) - self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + self.assertTrue(os.path.isfile(f"{tempdir}/test.ts")) - _, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.pt", more_extra_files=("test.txt",)) + _, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.ts", more_extra_files=("test.txt",)) if pytorch_after(1, 7): self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"]) From a76843cbec2a3b893fadb36d85a7ed8ec17378d9 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 22 Mar 2022 23:17:55 +0800 Subject: [PATCH 13/18] [DLMED] update min test Signed-off-by: Nic Ma --- tests/min_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/min_tests.py b/tests/min_tests.py index 3d93d96dda..9bf95f3f49 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -162,7 +162,7 @@ def run_testsuit(): "test_parallel_execution_dist", "test_bundle_verify_metadata", "test_bundle_verify_net", - "test_bundle_export", + "test_bundle_ckpt_export", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From db9800679cfd1be6818d27862b0313b081cfd71a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 23 Mar 2022 07:22:09 +0800 Subject: [PATCH 14/18] [DLMED] simplify args Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 79 +++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 9a346c5f44..13a5a19623 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -16,6 +16,7 @@ from typing import Dict, Optional, Sequence, Tuple, Union import torch +from torch.cuda import is_available from monai.apps.utils import download_url, get_logger from monai.bundle.config_parser import ConfigParser @@ -58,6 +59,14 @@ def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = Tr return args_ +def _pop_args(src: Dict, *args, **kwargs): + """ + Pop args from the `src` dictionary based on specified keys in `args` and (key, default value) pairs in `kwargs`. + + """ + return tuple([src.pop(i) for i in args] + [src.pop(k, v) for k, v in kwargs.items()]) + + def _log_input_summary(tag, args: Dict): logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---") for name, val in args.items(): @@ -152,20 +161,22 @@ def run( if "config_file" not in _args: raise ValueError(f"`config_file` is required for 'monai.bundle run'.\n{run.__doc__}") _log_input_summary(tag="run", args=_args) + config_file_, meta_file_, runner_id_ = _pop_args(_args, "config_file", meta_file=None, runner_id="") parser = ConfigParser() - parser.read_config(f=_args.pop("config_file")) - if "meta_file" in _args: - parser.read_meta(f=_args.pop("meta_file")) - id = _args.pop("runner_id", "") + parser.read_config(f=config_file_) + if meta_file_ is not None: + parser.read_meta(f=meta_file_) # the rest key-values in the _args are to override config content for k, v in _args.items(): parser[k] = v - workflow = parser.get_parsed_content(id=id) + workflow = parser.get_parsed_content(id=runner_id_) if not hasattr(workflow, "run"): - raise ValueError(f"The parsed workflow {type(workflow)} (id={id}) does not have a `run` method.\n{run.__doc__}") + raise ValueError( + f"The parsed workflow {type(workflow)} (id={runner_id_}) does not have a `run` method.\n{run.__doc__}" + ) return workflow.run() @@ -207,22 +218,16 @@ def verify_metadata( **kwargs, ) _log_input_summary(tag="verify_metadata", args=_args) + filepath_, meta_file_, create_dir_, hash_val_, hash_type_ = _pop_args( + _args, "filepath", "meta_file", create_dir=True, hash_val=None, hash_type="md5" + ) - filepath_ = _args.pop("filepath") - create_dir_ = _args.pop("create_dir", True) check_parent_dir(path=filepath_, create_dir=create_dir_) - - metadata = ConfigParser.load_config_files(files=_args.pop("meta_file")) + metadata = ConfigParser.load_config_files(files=meta_file_) url = metadata.get("schema") if url is None: raise ValueError("must provide the `schema` field in the metadata for the URL of schema file.") - download_url( - url=url, - filepath=filepath_, - hash_val=_args.pop("hash_val", None), - hash_type=_args.pop("hash_type", "md5"), - progress=True, - ) + download_url(url=url, filepath=filepath_, hash_val=hash_val_, hash_type=hash_type_, progress=True) schema = ConfigParser.load_config_file(filepath=filepath_) try: @@ -285,22 +290,20 @@ def verify_net_in_out( **override, ) _log_input_summary(tag="verify_net_in_out", args=_args) + config_file_, meta_file_, net_id_, device_, p_, n_, any_ = _pop_args( + _args, "config_file", "meta_file", net_id="", device="cuda:0" if is_available() else "cpu", p=1, n=1, any=1 + ) parser = ConfigParser() - parser.read_config(f=_args.pop("config_file")) - parser.read_meta(f=_args.pop("meta_file")) - id = _args.pop("net_id", "") - device_ = torch.device(_args.pop("device", "cuda:0" if torch.cuda.is_available() else "cpu")) - p = _args.pop("p", 1) - n = _args.pop("n", 1) - any = _args.pop("any", 1) + parser.read_config(f=config_file_) + parser.read_meta(f=meta_file_) # the rest key-values in the _args are to override config content for k, v in _args.items(): parser[k] = v try: - key: str = id # mark the full id when KeyError + key: str = net_id_ # mark the full id when KeyError net = parser.get_parsed_content(key).to(device_) key = "_meta_#network_data_format#inputs#image#num_channels" input_channels = parser[key] @@ -317,7 +320,7 @@ def verify_net_in_out( net.eval() with torch.no_grad(): - spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p, n=n, any=any) # type: ignore + spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p_, n=n_, any=any_) # type: ignore test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device_) output = net(test_data) if output.shape[1] != output_channels: @@ -372,38 +375,36 @@ def ckpt_export( **override, ) _log_input_summary(tag="export", args=_args) + filepath_, ckpt_file_, config_file_, net_id_, meta_file_, key_in_ckpt_ = _pop_args( + _args, "filepath", "ckpt_file", "config_file", net_id="", meta_file=None, key_in_ckpt="" + ) parser = ConfigParser() - config_file_ = _args.pop("config_file") + parser.read_config(f=config_file_) - meta_file = _args.pop("meta_file") - if meta_file is not None: - parser.read_meta(f=meta_file) - id = _args.pop("net_id", "") - path = _args.pop("filepath") - ckpt = torch.load(_args.pop("ckpt_file")) - key = _args.pop("key_in_ckpt", "") + if meta_file_ is not None: + parser.read_meta(f=meta_file_) # the rest key-values in the _args are to override config content for k, v in _args.items(): parser[k] = v - net = parser.get_parsed_content(id) + net = parser.get_parsed_content(net_id_) if has_ignite: # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver - Checkpoint.load_objects(to_load={key: net}, checkpoint=ckpt) + Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) else: - copy_model_state(dst=net, src=ckpt if key == "" else ckpt[key]) + copy_model_state(dst=net, src=ckpt_file_ if key_in_ckpt_ == "" else ckpt_file_[key_in_ckpt_]) # convert to TorchScript model and save with meta data, config content net = convert_to_torchscript(model=net) save_net_with_metadata( jit_obj=net, - filename_prefix_or_stream=path, + filename_prefix_or_stream=filepath_, include_config_vals=False, append_timestamp=False, meta_values=parser.get().pop("_meta_", None), more_extra_files={"config": json.dumps(parser.get()).encode()}, ) - logger.info(f"exported to TorchScript file: {path}.") + logger.info(f"exported to TorchScript file: {filepath_}.") From f352bd155954688cf22162944e0469cec4a45763 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 23 Mar 2022 09:55:11 +0000 Subject: [PATCH 15/18] fixes integration tests Signed-off-by: Wenqi Li --- tests/test_integration_workers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_integration_workers.py b/tests/test_integration_workers.py index 1f12f81712..9fa2beb3fb 100644 --- a/tests/test_integration_workers.py +++ b/tests/test_integration_workers.py @@ -16,7 +16,7 @@ from monai.data import DataLoader from monai.utils import set_determinism -from tests.utils import DistTestCase, TimedCall, skip_if_no_cuda, skip_if_quick +from tests.utils import DistTestCase, SkipIfBeforePyTorchVersion, TimedCall, skip_if_no_cuda, skip_if_quick def run_loading_test(num_workers=50, device="cuda:0" if torch.cuda.is_available() else "cpu", pw=False): @@ -38,15 +38,19 @@ def run_loading_test(num_workers=50, device="cuda:0" if torch.cuda.is_available( @skip_if_quick @skip_if_no_cuda +@SkipIfBeforePyTorchVersion((1, 9)) class IntegrationLoading(DistTestCase): def tearDown(self): set_determinism(seed=None) @TimedCall(seconds=5000, skip_timing=not torch.cuda.is_available(), daemon=False) def test_timing(self): - for pw, expected in zip((False, True), ((6966, 7714), (6966, 4112))): + expected = None + for pw in (False, True): result = run_loading_test(pw=pw) - np.testing.assert_allclose(result, expected) + if expected is None: + expected = result + np.testing.assert_allclose(result, expected) # test for deterministic in two settings if __name__ == "__main__": From 26f3ec4e5bebe5edb443f22de0d2ed408fbd8f75 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 23 Mar 2022 19:28:58 +0800 Subject: [PATCH 16/18] [DLMED] change to coverage Signed-off-by: Nic Ma --- tests/test_bundle_ckpt_export.py | 3 +-- tests/test_bundle_verify_metadata.py | 8 ++------ tests/test_bundle_verify_net.py | 5 ++--- tests/test_integration_bundle_run.py | 6 +++--- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py index 9dfec681f9..0f7d0f7d35 100644 --- a/tests/test_bundle_ckpt_export.py +++ b/tests/test_bundle_ckpt_export.py @@ -11,7 +11,6 @@ import os import subprocess -import sys import tempfile import unittest @@ -44,7 +43,7 @@ def test_export(self, key_in_ckpt): net = parser.get_parsed_content("network_def") save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) - cmd = [sys.executable, "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file] + cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file] cmd += ["--meta_file", meta_file, "--config_file", config_file, "--ckpt_file", ckpt_file] cmd += ["--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file] subprocess.check_call(cmd) diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py index 9449d8f73b..c816c081eb 100644 --- a/tests/test_bundle_verify_metadata.py +++ b/tests/test_bundle_verify_metadata.py @@ -10,10 +10,8 @@ # limitations under the License. import json -import logging import os import subprocess -import sys import tempfile import unittest @@ -40,18 +38,16 @@ def setUp(self): @parameterized.expand([TEST_CASE_1]) def test_verify(self, meta_file, schema_file): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) with tempfile.TemporaryDirectory() as tempdir: def_args = {"meta_file": "will be replaced by `meta_file` arg"} def_args_file = os.path.join(tempdir, "def_args.json") ConfigParser.export_config_file(config=def_args, filepath=def_args_file) - cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file] + cmd = ["coverage", "run", "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file] cmd += ["--filepath", schema_file, "--hash_val", self.config["hash_val"], "--args_file", def_args_file] subprocess.check_call(cmd) def test_verify_error(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) with tempfile.TemporaryDirectory() as tempdir: filepath = os.path.join(tempdir, "schema.json") metafile = os.path.join(tempdir, "metadata.json") @@ -59,7 +55,7 @@ def test_verify_error(self): with open(metafile, "w") as f: json.dump(meta_dict, f) - cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", metafile, "--filepath", filepath] + cmd = ["coverage", "run", "-m", "monai.bundle", "verify_metadata", metafile, "--filepath", filepath] subprocess.check_call(cmd) diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py index 5012cc85f4..33d480d83f 100644 --- a/tests/test_bundle_verify_net.py +++ b/tests/test_bundle_verify_net.py @@ -11,7 +11,6 @@ import os import subprocess -import sys import tempfile import unittest @@ -35,8 +34,8 @@ def test_verify(self, meta_file, config_file): def_args_file = os.path.join(tempdir, "def_args.json") ConfigParser.export_config_file(config=def_args, filepath=def_args_file) - cmd = [sys.executable, "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file", meta_file] - cmd += ["--config_file", config_file, "-n", "2", "--any", "32", "--args_file", def_args_file] + cmd = ["coverage", "run", "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file"] + cmd += [meta_file, "--config_file", config_file, "-n", "2", "--any", "32", "--args_file", def_args_file] cmd += ["--_meta_#network_data_format#inputs#image#spatial_shape", "[32,'*','4**p*n']"] test_env = os.environ.copy() diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index 90f5aa6408..af97d9e9ad 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -49,7 +49,7 @@ def test_tiny(self): config_file = os.path.join(self.data_dir, "tiny_config.json") with open(config_file, "w") as f: json.dump({"": {"_target_": "tests.test_integration_bundle_run._Runnable42", "val": 42}}, f) - cmd = [sys.executable, "-m", "monai.bundle", "run", "--config_file", config_file] + cmd = ["coverage", "run", "-m", "monai.bundle", "run", "--config_file", config_file] subprocess.check_call(cmd) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @@ -88,7 +88,7 @@ def test_shape(self, config_file, expected_shape): override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}" # test with `monai.bundle` as CLI entry directly cmd = f"-m monai.bundle run evaluator --postprocessing#transforms#2#output_postfix seg {override}" - la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] + la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] test_env = os.environ.copy() print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) subprocess.check_call(la + ["--args_file", def_args_file], env=test_env) @@ -97,7 +97,7 @@ def test_shape(self, config_file, expected_shape): # here test the script with `google fire` tool as CLI cmd = "-m fire monai.bundle.scripts run --runner_id evaluator" cmd += f" --evaluator#amp False {override}" - la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] + la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] subprocess.check_call(la, env=test_env) self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape) From fce2bd74a6963b714ccbb8da9e9d3675bb913e1d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 23 Mar 2022 12:16:11 +0000 Subject: [PATCH 17/18] update integration tests Signed-off-by: Wenqi Li --- tests/test_integration_workers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_integration_workers.py b/tests/test_integration_workers.py index 9fa2beb3fb..21515d1f82 100644 --- a/tests/test_integration_workers.py +++ b/tests/test_integration_workers.py @@ -49,8 +49,8 @@ def test_timing(self): for pw in (False, True): result = run_loading_test(pw=pw) if expected is None: - expected = result - np.testing.assert_allclose(result, expected) # test for deterministic in two settings + expected = result[0] + np.testing.assert_allclose(result[0], expected) # test for deterministic first epoch in two settings if __name__ == "__main__": From 1db8f00ca73eed0f7f900286483c16d7b5f68d08 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 24 Mar 2022 09:09:22 +0800 Subject: [PATCH 18/18] [DLMED] fix blossom Signed-off-by: Nic Ma --- tests/test_integration_workers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_integration_workers.py b/tests/test_integration_workers.py index 21515d1f82..654af5a89c 100644 --- a/tests/test_integration_workers.py +++ b/tests/test_integration_workers.py @@ -19,9 +19,11 @@ from tests.utils import DistTestCase, SkipIfBeforePyTorchVersion, TimedCall, skip_if_no_cuda, skip_if_quick -def run_loading_test(num_workers=50, device="cuda:0" if torch.cuda.is_available() else "cpu", pw=False): +def run_loading_test(num_workers=50, device=None, pw=False): """multi workers stress tests""" set_determinism(seed=0) + if device is None: + device = "cuda:0" if torch.cuda.is_available() else "cpu" train_ds = list(range(10000)) train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=num_workers, persistent_workers=pw) answer = []