Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
42a45e0
Merge pull request #19 from Project-MONAI/master
Nic-Ma Feb 1, 2021
cd16a13
Merge pull request #32 from Project-MONAI/master
Nic-Ma Feb 24, 2021
6f87afd
Merge pull request #180 from Project-MONAI/dev
Nic-Ma Jul 22, 2021
f398298
Merge pull request #214 from Project-MONAI/dev
Nic-Ma Sep 8, 2021
894d9b6
Merge pull request #385 from Project-MONAI/dev
Nic-Ma Mar 14, 2022
a9712a6
[DLMED] add verify script
Nic-Ma Mar 14, 2022
b2dcac7
[DLMED] fix typo
Nic-Ma Mar 14, 2022
00069e0
Merge branch 'dev' into 3482-verify-net-args
Nic-Ma Mar 15, 2022
2aa1b3f
[DLMED] update according to comments
Nic-Ma Mar 15, 2022
6292b52
[DLMED] add unit tests and doc
Nic-Ma Mar 15, 2022
6ceb6a0
[DLMED] fix flake8
Nic-Ma Mar 15, 2022
939b081
[DLMED] skip min tests
Nic-Ma Mar 15, 2022
ead76cd
Merge branch 'dev' into 3482-verify-net-args
Nic-Ma Mar 15, 2022
c48e476
Merge branch 'dev' into 3482-verify-net-args
Nic-Ma Mar 15, 2022
1fea761
Merge branch 'dev' into 3482-verify-net-args
Nic-Ma Mar 15, 2022
3e8b4c3
[DLMED] remove doc-string
Nic-Ma Mar 16, 2022
2084db1
[DLMED] fix typo
Nic-Ma Mar 16, 2022
0f250bc
[DLMED] update device names
Nic-Ma Mar 16, 2022
6ddc1b0
[DLMED] update doc-string examples
Nic-Ma Mar 16, 2022
7936559
[DLMED] enhance error message
Nic-Ma Mar 16, 2022
b2947ea
[DLMED] cpu:0 to cpu
Nic-Ma Mar 16, 2022
a348403
Merge branch 'dev' into 3482-verify-net-args
Nic-Ma Mar 16, 2022
d2727e7
Merge branch 'dev' into 3482-verify-net-args
Nic-Ma Mar 16, 2022
7f4055a
[DLMED] adjust "dataset_dir"
Nic-Ma Mar 17, 2022
09a0acf
Merge branch 'dev' into 3482-verify-net-args
Nic-Ma Mar 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/bundle.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ Model Bundle
---------
.. autofunction:: run
.. autofunction:: verify_metadata
.. autofunction:: verify_net_in_out
1 change: 0 additions & 1 deletion docs/source/mb_specification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ An example JSON metadata file:
"copyright": "Copyright (c) MONAI Consortium",
"data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/",
"data_type": "dicom",
"dataset_dir": "/workspace/data/Task09_Spleen",
"image_classes": "single channel data, intensity scaled to [0, 1]",
"label_classes": "single channel data, 1 is spleen, 0 is everything else",
"pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background",
Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
from .config_parser import ConfigParser
from .reference_resolver import ReferenceResolver
from .scripts import run, verify_metadata
from .scripts import run, verify_metadata, verify_net_in_out
from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
2 changes: 1 addition & 1 deletion monai/bundle/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.


from monai.bundle.scripts import run, verify_metadata
from monai.bundle.scripts import run, verify_metadata, verify_net_in_out

if __name__ == "__main__":
from monai.utils import optional_import
Expand Down
148 changes: 142 additions & 6 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import pprint
import re
from typing import Dict, Optional, Sequence, Union
from typing import Dict, Optional, Sequence, Tuple, Union

import torch

from monai.apps.utils import download_url, get_logger
from monai.bundle.config_parser import ConfigParser
from monai.config import PathLike
from monai.utils import check_parent_dir, optional_import
from monai.utils import check_parent_dir, get_equivalent_dtype, optional_import

validate, _ = optional_import("jsonschema", name="validate")
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
Expand Down Expand Up @@ -51,13 +54,54 @@ def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = Tr
return args_


def _log_input_summary(tag: str, args: Dict):
logger.info(f"\n--- input summary of monai.bundle.scripts.{tag} ---")
def _log_input_summary(tag, args: Dict):
logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---")
for name, val in args.items():
logger.info(f"> {name}: {pprint.pformat(val)}")
logger.info("---\n\n")


def _get_var_names(expr: str):
"""
Parse the expression and discover what variables are present in it based on ast module.

Args:
expr: source expression to parse.

"""
tree = ast.parse(expr)
return [m.id for m in ast.walk(tree) if isinstance(m, ast.Name)]


def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int = 1, any: int = 1) -> Tuple:
"""
Get spatial shape for fake data according to the specified shape pattern.
It supports `int` number and `string` with formats like: "32", "32 * n", "32 ** p", "32 ** p *n".

Args:
shape: specified pattern for the spatial shape.
p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1.
p: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1.
any: specified size to generate fake data shape if dim of expected shape is "*", default to 1.

"""
ret = []
for i in shape:
if isinstance(i, int):
ret.append(i)
elif isinstance(i, str):
if i == "*":
ret.append(any)
else:
for c in _get_var_names(i):
if c not in ["p", "n"]:
raise ValueError(f"only support variables 'p' and 'n' so far, but got: {c}.")
ret.append(eval(i, {"p": p, "n": n}))
else:
raise ValueError(f"spatial shape items must be int or string, but got: {type(i)} {i}.")
return tuple(ret)


def run(
runner_id: Optional[str] = None,
meta_file: Optional[Union[str, Sequence[str]]] = None,
Expand Down Expand Up @@ -94,8 +138,8 @@ def run(
if it is a list of file paths, the content of them will be merged.
config_file: filepath of the config file, if `None`, must be provided in `args_file`.
if it is a list of file paths, the content of them will be merged.
args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
`runner_id` and override pairs. so that the command line inputs can be simplified.
args_file: a JSON or YAML file to provide default values for `runner_id`, `meta_file`,
`config_file`, and override pairs. so that the command line inputs can be simplified.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``.

Expand Down Expand Up @@ -172,3 +216,95 @@ def verify_metadata(
logger.info(re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`.")
return
logger.info("metadata is verified with no error.")


def verify_net_in_out(
net_id: Optional[str] = None,
meta_file: Optional[Union[str, Sequence[str]]] = None,
config_file: Optional[Union[str, Sequence[str]]] = None,
device: Optional[str] = None,
p: Optional[int] = None,
n: Optional[int] = None,
any: Optional[int] = None,
args_file: Optional[str] = None,
**override,
):
"""
Verify the input and output data shape and data type of network defined in the metadata.
Will test with fake Tensor data according to the required data shape in `metadata`.

Typical usage examples:

.. code-block:: bash

python -m monai.bundle verify_net_in_out network --meta_file <meta path> --config_file <config path>

Args:
net_id: ID name of the network component to verify, it must be `torch.nn.Module`.
meta_file: filepath of the metadata file to get network args, if `None`, must be provided in `args_file`.
if it is a list of file paths, the content of them will be merged.
config_file: filepath of the config file to get network definition, if `None`, must be provided in `args_file`.
if it is a list of file paths, the content of them will be merged.
device: target device to run the network forward computation, if None, prefer to "cuda" if existing.
p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1.
p: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1.
any: specified size to generate fake data shape if dim of expected shape is "*", default to 1.
args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
`net_id` and override pairs. so that the command line inputs can be simplified.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.

"""

_args = _update_args(
args=args_file,
net_id=net_id,
meta_file=meta_file,
config_file=config_file,
device=device,
p=p,
n=n,
any=any,
**override,
)
_log_input_summary(tag="verify_net_in_out", args=_args)

parser = ConfigParser()
parser.read_config(f=_args.pop("config_file"))
parser.read_meta(f=_args.pop("meta_file"))
id = _args.pop("net_id", "")
device_ = torch.device(_args.pop("device", "cuda:0" if torch.cuda.is_available() else "cpu"))
p = _args.pop("p", 1)
n = _args.pop("n", 1)
any = _args.pop("any", 1)

# the rest key-values in the _args are to override config content
for k, v in _args.items():
parser[k] = v

try:
key: str = id # mark the full id when KeyError
net = parser.get_parsed_content(key).to(device_)
key = "_meta_#network_data_format#inputs#image#num_channels"
input_channels = parser[key]
key = "_meta_#network_data_format#inputs#image#spatial_shape"
input_spatial_shape = tuple(parser[key])
key = "_meta_#network_data_format#inputs#image#dtype"
input_dtype = get_equivalent_dtype(parser[key], torch.Tensor)
key = "_meta_#network_data_format#outputs#pred#num_channels"
output_channels = parser[key]
key = "_meta_#network_data_format#outputs#pred#dtype"
output_dtype = get_equivalent_dtype(parser[key], torch.Tensor)
except KeyError as e:
raise KeyError(f"Failed to verify due to missing expected key in the config: {key}.") from e

net.eval()
with torch.no_grad():
spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p, n=n, any=any) # type: ignore
test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device_)
output = net(test_data)
if output.shape[1] != output_channels:
raise ValueError(f"output channel number `{output.shape[1]}` doesn't match: `{output_channels}`.")
if output.dtype != output_dtype:
raise ValueError(f"dtype of output data `{output.dtype}` doesn't match: {output_dtype}.")
logger.info("data shape of network is verified with no error.")
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def run_testsuit():
"test_prepare_batch_default_dist",
"test_parallel_execution_dist",
"test_bundle_verify_metadata",
"test_bundle_verify_net",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
4 changes: 2 additions & 2 deletions tests/test_bundle_verify_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_verify(self, meta_file, schema_file):
def_args_file = os.path.join(tempdir, "def_args.json")
ConfigParser.export_config_file(config=def_args, filepath=def_args_file)

hash_val = "b11acc946148c0186924f8234562b947"
hash_val = "e3a7e23d1113a1f3e6c69f09b6f9ce2c"

cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file]
cmd += ["--filepath", schema_file, "--hash_val", hash_val, "--args_file", def_args_file]
Expand All @@ -54,7 +54,7 @@ def test_verify_error(self):
json.dump(
{
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/"
"download/0.8.1/meta_schema_202203130950.json",
"download/0.8.1/meta_schema_202203171008.json",
"wrong_meta": "wrong content",
},
f,
Expand Down
46 changes: 46 additions & 0 deletions tests/test_bundle_verify_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import sys
import tempfile
import unittest

from parameterized import parameterized

from monai.bundle import ConfigParser
from tests.utils import skip_if_windows

TEST_CASE_1 = [
os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json"),
os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"),
]


@skip_if_windows
class TestVerifyNetwork(unittest.TestCase):
@parameterized.expand([TEST_CASE_1])
def test_verify(self, meta_file, config_file):
with tempfile.TemporaryDirectory() as tempdir:
def_args = {"meta_file": "will be replaced by `meta_file` arg", "p": 2}
def_args_file = os.path.join(tempdir, "def_args.json")
ConfigParser.export_config_file(config=def_args, filepath=def_args_file)

cmd = [sys.executable, "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file", meta_file]
cmd += ["--config_file", config_file, "-n", "2", "--any", "32", "--args_file", def_args_file]
cmd += ["--_meta_#network_data_format#inputs#image#spatial_shape", "[32,'*','4**p*n']"]
ret = subprocess.check_call(cmd)
self.assertEqual(ret, 0)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions tests/testing_data/inference.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"dataset_dir": "/workspace/data/Task09_Spleen",
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"network_def": {
"_target_": "UNet",
Expand Down
1 change: 1 addition & 0 deletions tests/testing_data/inference.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
---
dataset_dir: "/workspace/data/Task09_Spleen"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
network_def:
_target_: UNet
Expand Down
3 changes: 1 addition & 2 deletions tests/testing_data/metadata.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_202203130950.json",
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_202203171008.json",
"version": "0.1.0",
"changelog": {
"0.1.0": "complete the model package",
Expand All @@ -17,7 +17,6 @@
"copyright": "Copyright (c) MONAI Consortium",
"data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/",
"data_type": "dicom",
"dataset_dir": "/workspace/data/Task09_Spleen",
"image_classes": "single channel data, intensity scaled to [0, 1]",
"label_classes": "single channel data, 1 is spleen, 0 is everything else",
"pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background",
Expand Down