Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
42a45e0
Merge pull request #19 from Project-MONAI/master
Nic-Ma Feb 1, 2021
cd16a13
Merge pull request #32 from Project-MONAI/master
Nic-Ma Feb 24, 2021
6f87afd
Merge pull request #180 from Project-MONAI/dev
Nic-Ma Jul 22, 2021
f398298
Merge pull request #214 from Project-MONAI/dev
Nic-Ma Sep 8, 2021
ec463d6
Merge pull request #397 from Project-MONAI/dev
Nic-Ma Apr 4, 2022
ca62306
Merge pull request #429 from Project-MONAI/dev
Nic-Ma Jul 8, 2022
7500a6a
Merge pull request #454 from Project-MONAI/dev
Nic-Ma Dec 6, 2022
1444973
[DLMED] save config as string param
Nic-Ma Dec 6, 2022
15d23d6
Merge branch 'dev' into 5648-track-bundle-config
Nic-Ma Dec 6, 2022
5340d1a
Merge branch 'dev' into 5648-track-bundle-config
Nic-Ma Dec 7, 2022
11d94a5
[DLMED] update according to comments
Nic-Ma Dec 7, 2022
46e9bc4
[DLMED] update according to comments
Nic-Ma Dec 7, 2022
a87d789
[DLMED] fix test
Nic-Ma Dec 7, 2022
1f35051
[DLMED] fix FL tests
Nic-Ma Dec 7, 2022
73ddbf1
Merge branch 'dev' into 5648-track-bundle-config
Nic-Ma Dec 7, 2022
2a74389
[DLMED] fix wrong path
Nic-Ma Dec 7, 2022
6dfce38
[DLMED] update according to the comments
Nic-Ma Dec 7, 2022
51cf7d3
Merge branch 'dev' into 5648-track-bundle-config
wyli Dec 7, 2022
2089dfb
Merge branch 'dev' into 5648-track-bundle-config
wyli Dec 7, 2022
f3ab12a
[DLMED] fix windows error and simplify utility
Nic-Ma Dec 8, 2022
4e14592
Merge branch 'dev' into 5648-track-bundle-config
Nic-Ma Dec 8, 2022
7c38fec
[DLMED] add absolute path
Nic-Ma Dec 8, 2022
00b9cc1
[DLMED] add default output_dir
Nic-Ma Dec 8, 2022
91bcf4d
Merge branch 'dev' into 5648-track-bundle-config
wyli Dec 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions monai/bundle/reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,21 +221,23 @@ def update_refs_pattern(cls, value: str, refs: Dict) -> str:
result = cls.id_matcher.findall(value)
value_is_expr = ConfigExpression.is_expression(value)
for item in result:
ref_id = item[len(cls.ref) :] # remove the ref prefix "@"
if ref_id not in refs:
msg = f"can not find expected ID '{ref_id}' in the references."
if cls.allow_missing_reference:
warnings.warn(msg)
continue
else:
raise KeyError(msg)
if value_is_expr:
# replace with local code, `{"__local_refs": self.resolved_content}` will be added to
# the `globals` argument of python `eval` in the `evaluate`
value = value.replace(item, f"{cls._vars}['{ref_id}']")
elif value == item:
# the whole content is "@XXX", it will avoid the case that regular string contains "@"
value = refs[ref_id]
# only update reference when string starts with "$" or the whole content is "@XXX"
if value_is_expr or value == item:
ref_id = item[len(cls.ref) :] # remove the ref prefix "@"
if ref_id not in refs:
msg = f"can not find expected ID '{ref_id}' in the references."
if cls.allow_missing_reference:
warnings.warn(msg)
continue
else:
raise KeyError(msg)
if value_is_expr:
# replace with local code, `{"__local_refs": self.resolved_content}` will be added to
# the `globals` argument of python `eval` in the `evaluate`
value = value.replace(item, f"{cls._vars}['{ref_id}']")
elif value == item:
# the whole content is "@XXX", it will avoid the case that regular string contains "@"
value = refs[ref_id]
return value

@classmethod
Expand Down
13 changes: 13 additions & 0 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import pprint
import re
import time
import warnings
from logging.config import fileConfig
from pathlib import Path
Expand Down Expand Up @@ -489,6 +490,18 @@ def patch_bundle_tracking(parser: ConfigParser, settings: dict):
handlers.append(v)
elif k not in parser:
parser[k] = v
# save the executed config into file
default_name = f"config_{time.strftime('%Y%m%d_%H%M%S')}.json"
filepath = parser.get("execute_config", None)
if filepath is None:
if "output_dir" not in parser:
# if no "output_dir" in the bundle config, default to "<bundle root>/eval"
parser["output_dir"] = "$@bundle_root + '/eval'"
# experiment management tools can refer to this config item to track the config info
parser["execute_config"] = parser["output_dir"] + f" + '/{default_name}'"
filepath = os.path.join(parser.get_parsed_content("output_dir"), default_name)
Path(filepath).parent.mkdir(parents=True, exist_ok=True)
parser.export_config_file(parser.get(), filepath)


def run(
Expand Down
9 changes: 8 additions & 1 deletion monai/bundle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,14 @@
DEFAULT_MLFLOW_SETTINGS = {
"handlers_id": DEFAULT_HANDLERS_ID,
"configs": {
"tracking_uri": "$@output_dir + '/mlruns'",
# if no "output_dir" in the bundle config, default to "<bundle root>/eval"
"output_dir": "$@bundle_root + '/eval'",
# use URI to support linux, mac and windows os
"tracking_uri": "$monai.utils.path_to_uri(@output_dir) + '/mlruns'",
"experiment_name": "monai_experiment",
"run_name": None,
# may fill it at runtime
"execute_config": None,
"is_not_rank0": (
"$torch.distributed.is_available() \
and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0"
Expand All @@ -118,6 +123,7 @@
"tracking_uri": "@tracking_uri",
"experiment_name": "@experiment_name",
"run_name": "@run_name",
"artifacts": "@execute_config",
"iteration_log": True,
"epoch_log": True,
"tag_name": "train_loss",
Expand All @@ -140,6 +146,7 @@
"tracking_uri": "@tracking_uri",
"experiment_name": "@experiment_name",
"run_name": "@run_name",
"artifacts": "@execute_config",
"iteration_log": False,
"close_on_complete": True,
},
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
is_scalar_tensor,
issequenceiterable,
list_to_dict,
path_to_uri,
progress_bar,
sample_slices,
save_obj,
Expand Down
12 changes: 12 additions & 0 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"check_parent_dir",
"save_obj",
"label_union",
"path_to_uri",
]

_seed = None
Expand Down Expand Up @@ -584,3 +585,14 @@ def prob2class(x, sigmoid: bool = False, threshold: float = 0.5, **kwargs):
threshold: threshold value to activate the sigmoid function.
"""
return torch.argmax(x, **kwargs) if not sigmoid else (x > threshold).int()


def path_to_uri(path: PathLike) -> str:
"""
Convert a file path to URI. if not absolute path, will convert to absolute path first.

Args:
path: input file path to convert, can be a string or `Path` object.

"""
return Path(path).absolute().as_uri()
8 changes: 5 additions & 3 deletions tests/test_fl_monai_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import shutil
import tempfile
import unittest
from pathlib import Path

from parameterized import parameterized

Expand All @@ -22,6 +21,7 @@
from monai.fl.client.monai_algo import MonaiAlgo
from monai.fl.utils.constants import ExtraItems
from monai.fl.utils.exchange_object import ExchangeObject
from monai.utils import path_to_uri
from tests.utils import SkipIfNoModule

_root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__)))
Expand Down Expand Up @@ -151,12 +151,13 @@ def test_train(self, input_params):
input_params["tracking"] = {
"handlers_id": DEFAULT_HANDLERS_ID,
"configs": {
"execute_config": f"{data_dir}/config_executed.json",
"trainer": {
"_target_": "MLFlowHandler",
"tracking_uri": Path(data_dir).as_uri() + "/mlflow_override",
"tracking_uri": path_to_uri(data_dir) + "/mlflow_override",
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
"close_on_complete": True,
}
},
},
}

Expand All @@ -177,6 +178,7 @@ def test_train(self, input_params):
algo.train(data=data, extra={})
algo.finalize()
self.assertTrue(os.path.exists(f"{data_dir}/mlflow_override"))
self.assertTrue(os.path.exists(f"{data_dir}/config_executed.json"))
shutil.rmtree(data_dir)

@parameterized.expand([TEST_EVALUATE_1, TEST_EVALUATE_2, TEST_EVALUATE_3])
Expand Down
4 changes: 2 additions & 2 deletions tests/test_handler_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
import os
import tempfile
import unittest
from pathlib import Path

import numpy as np
from ignite.engine import Engine, Events

from monai.handlers import MLFlowHandler
from monai.utils import path_to_uri


class TestHandlerMLFlow(unittest.TestCase):
Expand Down Expand Up @@ -49,7 +49,7 @@ def _update_metric(engine):
handler = MLFlowHandler(
iteration_log=False,
epoch_log=True,
tracking_uri=Path(test_path).as_uri(),
tracking_uri=path_to_uri(test_path),
state_attributes=["test"],
experiment_param=experiment_param,
artifacts=[artifact_path],
Expand Down
15 changes: 8 additions & 7 deletions tests/test_integration_bundle_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import sys
import tempfile
import unittest
from pathlib import Path
from glob import glob

import nibabel as nib
import numpy as np
Expand All @@ -24,6 +24,7 @@
from monai.bundle import ConfigParser
from monai.bundle.utils import DEFAULT_HANDLERS_ID
from monai.transforms import LoadImage
from monai.utils import path_to_uri
from tests.utils import command_line_tests

TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"), (128, 128, 128)]
Expand Down Expand Up @@ -85,7 +86,7 @@ def test_shape(self, config_file, expected_shape):
"no_epoch": True, # test override config in the settings file
"evaluator": {
"_target_": "MLFlowHandler",
"tracking_uri": "$@output_dir + '/mlflow_override1'",
"tracking_uri": "$monai.utils.path_to_uri(@output_dir) + '/mlflow_override1'",
"iteration_log": "@no_epoch",
},
},
Expand All @@ -105,14 +106,12 @@ def test_shape(self, config_file, expected_shape):
json.dump("Dataset", f)

if sys.platform == "win32":
outdir = Path(tempdir).as_uri()
override = "--network $@network_def.to(@device) --dataset#_target_ Dataset"
else:
outdir = tempdir
override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}"
# test with `monai.bundle` as CLI entry directly
cmd = "-m monai.bundle run evaluating --postprocessing#transforms#2#output_postfix seg"
cmd += f" {override} --no_epoch False --save_dir {tempdir} --output_dir {outdir}"
cmd += f" {override} --no_epoch False --output_dir {tempdir}"
la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
test_env = os.environ.copy()
print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES"))
Expand All @@ -121,14 +120,16 @@ def test_shape(self, config_file, expected_shape):
self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape)
self.assertTrue(os.path.exists(f"{tempdir}/mlflow_override1"))

tracking_uri = outdir + "/mlflow_override2" # test override experiment management configs
tracking_uri = path_to_uri(tempdir) + "/mlflow_override2" # test override experiment management configs
# here test the script with `google fire` tool as CLI
cmd = "-m fire monai.bundle.scripts run --runner_id evaluating --tracking mlflow --evaluator#amp False"
cmd += f" --tracking_uri {tracking_uri} {override} --save_dir {tempdir} --output_dir {outdir}"
cmd += f" --tracking_uri {tracking_uri} {override} --output_dir {tempdir}"
la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
command_line_tests(la)
self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape)
self.assertTrue(os.path.exists(f"{tempdir}/mlflow_override2"))
# test the saved execution configs
self.assertTrue(len(glob(f"{tempdir}/config_*.json")), 2)


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions tests/testing_data/inference.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
{
"dataset_dir": "/workspace/data/Task09_Spleen",
"save_dir": "need_override",
"output_dir": "need override",
"prediction_shape": "prediction shape:",
"import_glob": "$import glob",
Expand Down Expand Up @@ -89,7 +88,7 @@
{
"_target_": "SaveImaged",
"keys": "pred",
"output_dir": "@save_dir"
"output_dir": "@output_dir"
},
{
"_target_": "Lambdad",
Expand Down
3 changes: 1 addition & 2 deletions tests/testing_data/inference.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
---
dataset_dir: "/workspace/data/Task09_Spleen"
save_dir: "need_override"
output_dir: "need override"
prediction_shape: "prediction shape:"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
Expand Down Expand Up @@ -65,7 +64,7 @@ postprocessing:
argmax: true
- _target_: SaveImaged
keys: pred
output_dir: "@save_dir"
output_dir: "@output_dir"
- _target_: Lambdad
keys: pred
func: "$lambda x: print(@prediction_shape + str(x.shape))"
Expand Down