From 1927d7806f1c0bd31c2b108ea35ef5fa6688cdb4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 6 Apr 2022 12:40:05 +0800 Subject: [PATCH] [DLMED] add allow_missing_reference Signed-off-by: Nic Ma --- monai/bundle/reference_resolver.py | 18 ++++++++++++++++-- tests/test_config_parser.py | 25 ++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index f9f73c9c71..c169389e0d 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import re +import warnings from typing import Any, Dict, Optional, Sequence, Set from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem @@ -50,6 +52,8 @@ class ReferenceResolver: ref = ID_REF_KEY # reference prefix # match a reference string, e.g. "@id#key", "@id#key#0", "@_target_#key" id_matcher = re.compile(rf"{ref}(?:\w*)(?:{sep}\w*)*") + # if `allow_missing_reference` and can't find a reference ID, will just raise a warning and don't update the config + allow_missing_reference = False if os.environ.get("MONAI_ALLOW_MISSING_REFERENCE", "0") == "0" else True def __init__(self, items: Optional[Sequence[ConfigItem]] = None): # save the items in a dictionary with the `ConfigItem.id` as key @@ -140,7 +144,12 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, ** try: look_up_option(d, self.items, print_all_options=False) except ValueError as err: - raise ValueError(f"the referring item `@{d}` is not defined in the config content.") from err + msg = f"the referring item `@{d}` is not defined in the config content." + if self.allow_missing_reference: + warnings.warn(msg) + continue + else: + raise ValueError(msg) from err # recursively resolve the reference first self._resolve_one_item(id=d, waiting_list=waiting_list, **kwargs) waiting_list.discard(d) @@ -210,7 +219,12 @@ def update_refs_pattern(cls, value: str, refs: Dict) -> str: for item in result: ref_id = item[len(cls.ref) :] # remove the ref prefix "@" if ref_id not in refs: - raise KeyError(f"can not find expected ID '{ref_id}' in the references.") + msg = f"can not find expected ID '{ref_id}' in the references." + if cls.allow_missing_reference: + warnings.warn(msg) + continue + else: + raise KeyError(msg) if value_is_expr: # replace with local code, will be used in the `evaluate` logic with `locals={"refs": ...}` value = value.replace(item, f"{cls._vars}['{ref_id}']") diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 9c727c29ac..9ab002f7af 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -16,7 +16,7 @@ from parameterized import parameterized -from monai.bundle.config_parser import ConfigParser +from monai.bundle import ConfigParser, ReferenceResolver from monai.data import DataLoader, Dataset from monai.transforms import Compose, LoadImaged, RandTorchVisiond from monai.utils import min_version, optional_import @@ -86,6 +86,8 @@ def __call__(self, a, b): } ] +TEST_CASE_4 = [{"A": 1, "B": "@A", "C": "@D", "E": "$'test' + '@F'"}] + class TestConfigParser(unittest.TestCase): def test_config_content(self): @@ -154,6 +156,27 @@ def test_macro_replace(self): parser.resolve_macro_and_relative_ids() self.assertEqual(str(parser.get()), str({"A": {"B": 1, "C": 2}, "D": [3, 1, 3, 4]})) + @parameterized.expand([TEST_CASE_4]) + def test_allow_missing_reference(self, config): + default = ReferenceResolver.allow_missing_reference + ReferenceResolver.allow_missing_reference = True + parser = ConfigParser(config=config) + + for id in config: + item = parser.get_parsed_content(id=id) + if id in ("A", "B"): + self.assertEqual(item, 1) + elif id == "C": + self.assertEqual(item, "@D") + elif id == "E": + self.assertEqual(item, "test@F") + + # restore the default value + ReferenceResolver.allow_missing_reference = default + with self.assertRaises(ValueError): + parser.parse() + parser.get_parsed_content(id="E") + if __name__ == "__main__": unittest.main()