Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions monai/bundle/reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import warnings
from typing import Any, Dict, Optional, Sequence, Set

from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem
Expand Down Expand Up @@ -50,6 +52,8 @@ class ReferenceResolver:
ref = ID_REF_KEY # reference prefix
# match a reference string, e.g. "@id#key", "@id#key#0", "@_target_#key"
id_matcher = re.compile(rf"{ref}(?:\w*)(?:{sep}\w*)*")
# if `allow_missing_reference` and can't find a reference ID, will just raise a warning and don't update the config
allow_missing_reference = False if os.environ.get("MONAI_ALLOW_MISSING_REFERENCE", "0") == "0" else True

def __init__(self, items: Optional[Sequence[ConfigItem]] = None):
# save the items in a dictionary with the `ConfigItem.id` as key
Expand Down Expand Up @@ -140,7 +144,12 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **
try:
look_up_option(d, self.items, print_all_options=False)
except ValueError as err:
raise ValueError(f"the referring item `@{d}` is not defined in the config content.") from err
msg = f"the referring item `@{d}` is not defined in the config content."
if self.allow_missing_reference:
warnings.warn(msg)
continue
else:
raise ValueError(msg) from err
# recursively resolve the reference first
self._resolve_one_item(id=d, waiting_list=waiting_list, **kwargs)
waiting_list.discard(d)
Expand Down Expand Up @@ -210,7 +219,12 @@ def update_refs_pattern(cls, value: str, refs: Dict) -> str:
for item in result:
ref_id = item[len(cls.ref) :] # remove the ref prefix "@"
if ref_id not in refs:
raise KeyError(f"can not find expected ID '{ref_id}' in the references.")
msg = f"can not find expected ID '{ref_id}' in the references."
if cls.allow_missing_reference:
warnings.warn(msg)
continue
else:
raise KeyError(msg)
if value_is_expr:
# replace with local code, will be used in the `evaluate` logic with `locals={"refs": ...}`
value = value.replace(item, f"{cls._vars}['{ref_id}']")
Expand Down
25 changes: 24 additions & 1 deletion tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from parameterized import parameterized

from monai.bundle.config_parser import ConfigParser
from monai.bundle import ConfigParser, ReferenceResolver
from monai.data import DataLoader, Dataset
from monai.transforms import Compose, LoadImaged, RandTorchVisiond
from monai.utils import min_version, optional_import
Expand Down Expand Up @@ -86,6 +86,8 @@ def __call__(self, a, b):
}
]

TEST_CASE_4 = [{"A": 1, "B": "@A", "C": "@D", "E": "$'test' + '@F'"}]


class TestConfigParser(unittest.TestCase):
def test_config_content(self):
Expand Down Expand Up @@ -154,6 +156,27 @@ def test_macro_replace(self):
parser.resolve_macro_and_relative_ids()
self.assertEqual(str(parser.get()), str({"A": {"B": 1, "C": 2}, "D": [3, 1, 3, 4]}))

@parameterized.expand([TEST_CASE_4])
def test_allow_missing_reference(self, config):
default = ReferenceResolver.allow_missing_reference
ReferenceResolver.allow_missing_reference = True
parser = ConfigParser(config=config)

for id in config:
item = parser.get_parsed_content(id=id)
if id in ("A", "B"):
self.assertEqual(item, 1)
elif id == "C":
self.assertEqual(item, "@D")
elif id == "E":
self.assertEqual(item, "test@F")

# restore the default value
ReferenceResolver.allow_missing_reference = default
with self.assertRaises(ValueError):
parser.parse()
parser.get_parsed_content(id="E")


if __name__ == "__main__":
unittest.main()