diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index a9107a10e9..d899cfcb08 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -13,7 +13,7 @@ import re from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.reference_resolver import ReferenceResolver @@ -253,7 +253,7 @@ def read_config(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs): content.update(self.load_config_files(f, **kwargs)) self.set(config=content) - def _do_resolve(self, config: Any, id: str = ""): + def _do_resolve(self, config: Any, id: str = "", waiting_list: Optional[Set[str]] = None): """ Recursively resolve `self.config` to replace the relative ids with absolute ids, for example, `@##A` means `A` in the upper level. and replace the macro tokens with target content, @@ -266,18 +266,32 @@ def _do_resolve(self, config: Any, id: str = ""): go one level further into the nested structures. Use digits indexing from "0" for list or other strings for dict. For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``. + waiting_list: set of macro replacement ids pending to be resolved. + It's used to detect circular references such as: + `{"A": {"dep": "%B"}, "B": {"dep": "%A"}}`. """ + if waiting_list is None: + waiting_list = set() if isinstance(config, (dict, list)): for k, v in enumerate(config) if isinstance(config, list) else config.items(): sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k - config[k] = self._do_resolve(v, sub_id) + config[k] = self._do_resolve(v, sub_id, waiting_list) if isinstance(config, str): config = self.resolve_relative_ids(id, config) if config.startswith(MACRO_KEY): + waiting_list.add(id) path, ids = ConfigParser.split_path_id(config[len(MACRO_KEY) :]) - parser = ConfigParser(config=self.get() if not path else ConfigParser.load_config_file(path)) - return self._do_resolve(config=deepcopy(parser[ids])) + if not path: + # if the target id is in the waiting list, that's circular references + if ids in waiting_list: + raise ValueError(f"detected circular references in macro replacement '{ids}' for id='{id}'.") + parser = ConfigParser(config=self.get()) + config = self._do_resolve(deepcopy(parser[ids]), ids, waiting_list) + else: + # don't support recursive macro replacement in another config file + config = ConfigParser(config=ConfigParser.load_config_file(path))[ids] + waiting_list.discard(id) return config def resolve_macro_and_relative_ids(self): diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index f9f73c9c71..a98af82687 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -106,7 +106,7 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, ** id: id name of ``ConfigItem`` to be resolved. waiting_list: set of ids pending to be resolved. It's used to detect circular references such as: - `{"name": "A", "dep": "@B"}` and `{"name": "B", "dep": "@A"}`. + `{"A": {"dep": "@B"}, "B": {"dep": "@A"}}`. kwargs: keyword arguments to pass to ``_resolve_one_item()``. Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True. diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 8b1076b1f7..f7982bfe1a 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import tempfile import unittest from unittest import skipUnless @@ -142,6 +144,22 @@ def test_relative_id(self, config): if isinstance(item, dict): self.assertEqual(str(item), str({"key": 1, "value1": 2, "value2": 2, "value3": [3, 4, 4, 105]})) + def test_macro_replace(self): + with tempfile.TemporaryDirectory() as tempdir: + another_file = os.path.join(tempdir, "another.json") + ConfigParser.export_config_file(config={"E": 4}, filepath=another_file) + # test relative id, recursive macro replacement, and macro in another file + config = {"A": {"B": 1, "C": 2}, "D": [3, "%A#B", "%#1", f"%{another_file}#E"]} + parser = ConfigParser(config=config) + parser.resolve_macro_and_relative_ids() + self.assertEqual(str(parser.get()), str({"A": {"B": 1, "C": 2}, "D": [3, 1, 1, 4]})) + + def test_circular_macro_replace(self): + config = {"A": "%B", "B": {"args": [1, 2, "%A"]}} + parser = ConfigParser(config=config) + with self.assertRaises(ValueError): + parser.resolve_macro_and_relative_ids() + if __name__ == "__main__": unittest.main()