diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index ed40d7099a..ff2eaf4053 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -221,21 +221,23 @@ def update_refs_pattern(cls, value: str, refs: Dict) -> str: result = cls.id_matcher.findall(value) value_is_expr = ConfigExpression.is_expression(value) for item in result: - ref_id = item[len(cls.ref) :] # remove the ref prefix "@" - if ref_id not in refs: - msg = f"can not find expected ID '{ref_id}' in the references." - if cls.allow_missing_reference: - warnings.warn(msg) - continue - else: - raise KeyError(msg) - if value_is_expr: - # replace with local code, `{"__local_refs": self.resolved_content}` will be added to - # the `globals` argument of python `eval` in the `evaluate` - value = value.replace(item, f"{cls._vars}['{ref_id}']") - elif value == item: - # the whole content is "@XXX", it will avoid the case that regular string contains "@" - value = refs[ref_id] + # only update reference when string starts with "$" or the whole content is "@XXX" + if value_is_expr or value == item: + ref_id = item[len(cls.ref) :] # remove the ref prefix "@" + if ref_id not in refs: + msg = f"can not find expected ID '{ref_id}' in the references." + if cls.allow_missing_reference: + warnings.warn(msg) + continue + else: + raise KeyError(msg) + if value_is_expr: + # replace with local code, `{"__local_refs": self.resolved_content}` will be added to + # the `globals` argument of python `eval` in the `evaluate` + value = value.replace(item, f"{cls._vars}['{ref_id}']") + elif value == item: + # the whole content is "@XXX", it will avoid the case that regular string contains "@" + value = refs[ref_id] return value @classmethod diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 6312b088d2..d61d4105ec 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -14,6 +14,7 @@ import os import pprint import re +import time import warnings from logging.config import fileConfig from pathlib import Path @@ -489,6 +490,18 @@ def patch_bundle_tracking(parser: ConfigParser, settings: dict): handlers.append(v) elif k not in parser: parser[k] = v + # save the executed config into file + default_name = f"config_{time.strftime('%Y%m%d_%H%M%S')}.json" + filepath = parser.get("execute_config", None) + if filepath is None: + if "output_dir" not in parser: + # if no "output_dir" in the bundle config, default to "/eval" + parser["output_dir"] = "$@bundle_root + '/eval'" + # experiment management tools can refer to this config item to track the config info + parser["execute_config"] = parser["output_dir"] + f" + '/{default_name}'" + filepath = os.path.join(parser.get_parsed_content("output_dir"), default_name) + Path(filepath).parent.mkdir(parents=True, exist_ok=True) + parser.export_config_file(parser.get(), filepath) def run( diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py index 8ffcd11208..77112b0db3 100644 --- a/monai/bundle/utils.py +++ b/monai/bundle/utils.py @@ -104,9 +104,14 @@ DEFAULT_MLFLOW_SETTINGS = { "handlers_id": DEFAULT_HANDLERS_ID, "configs": { - "tracking_uri": "$@output_dir + '/mlruns'", + # if no "output_dir" in the bundle config, default to "/eval" + "output_dir": "$@bundle_root + '/eval'", + # use URI to support linux, mac and windows os + "tracking_uri": "$monai.utils.path_to_uri(@output_dir) + '/mlruns'", "experiment_name": "monai_experiment", "run_name": None, + # may fill it at runtime + "execute_config": None, "is_not_rank0": ( "$torch.distributed.is_available() \ and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0" @@ -118,6 +123,7 @@ "tracking_uri": "@tracking_uri", "experiment_name": "@experiment_name", "run_name": "@run_name", + "artifacts": "@execute_config", "iteration_log": True, "epoch_log": True, "tag_name": "train_loss", @@ -140,6 +146,7 @@ "tracking_uri": "@tracking_uri", "experiment_name": "@experiment_name", "run_name": "@run_name", + "artifacts": "@execute_config", "iteration_log": False, "close_on_complete": True, }, diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 21d3621090..6dc12a0254 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -73,6 +73,7 @@ is_scalar_tensor, issequenceiterable, list_to_dict, + path_to_uri, progress_bar, sample_slices, save_obj, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index d400e7b64c..3569a76276 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -56,6 +56,7 @@ "check_parent_dir", "save_obj", "label_union", + "path_to_uri", ] _seed = None @@ -584,3 +585,14 @@ def prob2class(x, sigmoid: bool = False, threshold: float = 0.5, **kwargs): threshold: threshold value to activate the sigmoid function. """ return torch.argmax(x, **kwargs) if not sigmoid else (x > threshold).int() + + +def path_to_uri(path: PathLike) -> str: + """ + Convert a file path to URI. if not absolute path, will convert to absolute path first. + + Args: + path: input file path to convert, can be a string or `Path` object. + + """ + return Path(path).absolute().as_uri() diff --git a/tests/test_fl_monai_algo.py b/tests/test_fl_monai_algo.py index 426bb44cfc..cce01a169f 100644 --- a/tests/test_fl_monai_algo.py +++ b/tests/test_fl_monai_algo.py @@ -13,7 +13,6 @@ import shutil import tempfile import unittest -from pathlib import Path from parameterized import parameterized @@ -22,6 +21,7 @@ from monai.fl.client.monai_algo import MonaiAlgo from monai.fl.utils.constants import ExtraItems from monai.fl.utils.exchange_object import ExchangeObject +from monai.utils import path_to_uri from tests.utils import SkipIfNoModule _root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__))) @@ -151,12 +151,13 @@ def test_train(self, input_params): input_params["tracking"] = { "handlers_id": DEFAULT_HANDLERS_ID, "configs": { + "execute_config": f"{data_dir}/config_executed.json", "trainer": { "_target_": "MLFlowHandler", - "tracking_uri": Path(data_dir).as_uri() + "/mlflow_override", + "tracking_uri": path_to_uri(data_dir) + "/mlflow_override", "output_transform": "$monai.handlers.from_engine(['loss'], first=True)", "close_on_complete": True, - } + }, }, } @@ -177,6 +178,7 @@ def test_train(self, input_params): algo.train(data=data, extra={}) algo.finalize() self.assertTrue(os.path.exists(f"{data_dir}/mlflow_override")) + self.assertTrue(os.path.exists(f"{data_dir}/config_executed.json")) shutil.rmtree(data_dir) @parameterized.expand([TEST_EVALUATE_1, TEST_EVALUATE_2, TEST_EVALUATE_3]) diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py index f79707384f..3f43f97fbe 100644 --- a/tests/test_handler_mlflow.py +++ b/tests/test_handler_mlflow.py @@ -13,12 +13,12 @@ import os import tempfile import unittest -from pathlib import Path import numpy as np from ignite.engine import Engine, Events from monai.handlers import MLFlowHandler +from monai.utils import path_to_uri class TestHandlerMLFlow(unittest.TestCase): @@ -49,7 +49,7 @@ def _update_metric(engine): handler = MLFlowHandler( iteration_log=False, epoch_log=True, - tracking_uri=Path(test_path).as_uri(), + tracking_uri=path_to_uri(test_path), state_attributes=["test"], experiment_param=experiment_param, artifacts=[artifact_path], diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index 986dd77076..1b3583a911 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -15,7 +15,7 @@ import sys import tempfile import unittest -from pathlib import Path +from glob import glob import nibabel as nib import numpy as np @@ -24,6 +24,7 @@ from monai.bundle import ConfigParser from monai.bundle.utils import DEFAULT_HANDLERS_ID from monai.transforms import LoadImage +from monai.utils import path_to_uri from tests.utils import command_line_tests TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"), (128, 128, 128)] @@ -85,7 +86,7 @@ def test_shape(self, config_file, expected_shape): "no_epoch": True, # test override config in the settings file "evaluator": { "_target_": "MLFlowHandler", - "tracking_uri": "$@output_dir + '/mlflow_override1'", + "tracking_uri": "$monai.utils.path_to_uri(@output_dir) + '/mlflow_override1'", "iteration_log": "@no_epoch", }, }, @@ -105,14 +106,12 @@ def test_shape(self, config_file, expected_shape): json.dump("Dataset", f) if sys.platform == "win32": - outdir = Path(tempdir).as_uri() override = "--network $@network_def.to(@device) --dataset#_target_ Dataset" else: - outdir = tempdir override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}" # test with `monai.bundle` as CLI entry directly cmd = "-m monai.bundle run evaluating --postprocessing#transforms#2#output_postfix seg" - cmd += f" {override} --no_epoch False --save_dir {tempdir} --output_dir {outdir}" + cmd += f" {override} --no_epoch False --output_dir {tempdir}" la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] test_env = os.environ.copy() print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) @@ -121,14 +120,16 @@ def test_shape(self, config_file, expected_shape): self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape) self.assertTrue(os.path.exists(f"{tempdir}/mlflow_override1")) - tracking_uri = outdir + "/mlflow_override2" # test override experiment management configs + tracking_uri = path_to_uri(tempdir) + "/mlflow_override2" # test override experiment management configs # here test the script with `google fire` tool as CLI cmd = "-m fire monai.bundle.scripts run --runner_id evaluating --tracking mlflow --evaluator#amp False" - cmd += f" --tracking_uri {tracking_uri} {override} --save_dir {tempdir} --output_dir {outdir}" + cmd += f" --tracking_uri {tracking_uri} {override} --output_dir {tempdir}" la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] command_line_tests(la) self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape) self.assertTrue(os.path.exists(f"{tempdir}/mlflow_override2")) + # test the saved execution configs + self.assertTrue(len(glob(f"{tempdir}/config_*.json")), 2) if __name__ == "__main__": diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index 1560f61721..c222667101 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -1,6 +1,5 @@ { "dataset_dir": "/workspace/data/Task09_Spleen", - "save_dir": "need_override", "output_dir": "need override", "prediction_shape": "prediction shape:", "import_glob": "$import glob", @@ -89,7 +88,7 @@ { "_target_": "SaveImaged", "keys": "pred", - "output_dir": "@save_dir" + "output_dir": "@output_dir" }, { "_target_": "Lambdad", diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml index 072d93e883..a289b549db 100644 --- a/tests/testing_data/inference.yaml +++ b/tests/testing_data/inference.yaml @@ -1,6 +1,5 @@ --- dataset_dir: "/workspace/data/Task09_Spleen" -save_dir: "need_override" output_dir: "need override" prediction_shape: "prediction shape:" device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" @@ -65,7 +64,7 @@ postprocessing: argmax: true - _target_: SaveImaged keys: pred - output_dir: "@save_dir" + output_dir: "@output_dir" - _target_: Lambdad keys: pred func: "$lambda x: print(@prediction_shape + str(x.shape))"