Skip to content

Commit

Permalink
Add feature flag, and docs and improve tests
Browse files Browse the repository at this point in the history
Signed-off-by: John Zielke <j.l.zielke@gmail.com>
  • Loading branch information
johnzielke committed Apr 30, 2024
1 parent 9ef880c commit 16f579f
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 41 deletions.
56 changes: 56 additions & 0 deletions docs/source/config_syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,62 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
- `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``,
see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).

### Wrapping config components
**EXPERIMENTAL FEATURE**

Sometimes it can be necessary to wrap (i.e. decorate) a component in the config without
shifting the configuration tree one level down.
Take the following configuration as an example:

```json
{
"model": {
"_target_": "monai.networks.nets.BasicUNet",
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 2,
"features": [16, 16, 32, 32, 64, 64]
}
}
```
If we wanted to use `torch.compile` to speed up the model, we would have to write a configuration like this:

```json
{
"model": {
"_target_": "torch::jit::compile",
"model": {
"_target_": "monai.networks.nets.BasicUNet",
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 2,
"features": [16, 16, 32, 32, 64, 64]
}
}
}
```
This means we now need to adjust all references to parameters like `model.spatial_dims` to `model.model.spatial_dims`
throughout our code and configuration.
To avoid this, we can use the `_wrapper_` key to wrap the model in the configuration:

```json
{
"model": {
"_target_": "monai.networks.nets.BasicUNet",
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 2,
"features": [16, 16, 32, 32, 64, 64],
"_wrapper_": {
"_target_": "torch::jit::compile",
"_mode_": "callable"
}
}
}
```

Note that when accessing `@model` in the configuration, the model object will be the compiled model now.

## The command line interface

In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle.
Expand Down
75 changes: 49 additions & 26 deletions monai/bundle/config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@
from monai.utils import CompInitMode, ensure_tuple, first, instantiate, optional_import, run_debug, run_eval

__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent", "Instantiable"]
CONFIG_COMPONENT_KEY_WRAPPER = "_wrapper_"

from monai.utils.feature_flag import FeatureFlag

CONFIG_COMPONENT_KEY_MODE = "_mode_"
CONFIG_COMPONENT_KEY_DESC = "_desc_"
CONFIG_COMPONENT_KEY_REQUIRES = "_requires_"
CONFIG_COMPONENT_KEY_DISABLED = "_disabled_"
CONFIG_COMPONENT_KEY_TARGET = "_target_"
CONFIG_COMPONENT_KEY_WRAPPER = "_wrapper_"

_wrapper_feature_flag = FeatureFlag("CONFIG_WRAPPER", default=False)


class Instantiable(ABC):
Expand Down Expand Up @@ -172,7 +177,7 @@ class ConfigComponent(ConfigItem, Instantiable):
Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to
represent a component of `class` or `function` and supports instantiation.
Currently, three special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:
Currently, four special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:
- class or function identifier of the python module, specified by ``"_target_"``,
indicating a monai built-in Python class or function such as ``"LoadImageDict"``,
Expand All @@ -189,6 +194,12 @@ class ConfigComponent(ConfigItem, Instantiable):
- ``"default"``: returns ``component(**kwargs)``
- ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``
- ``"debug"``: returns ``pdb.runcall(component, **kwargs)``
- ``"_wrapper_"`` (optional): a callable that wraps the instantiation of the component.
This feature is currently experimental and hidden behind a feature flag. To enable it, set the
environment variable ``MONAI_FEATURE_ENABLED_CONFIG_WRAPPER=1`` or
call monai.bundle.config_item._wrapper_feature_flag.enable().
The callable should take the instantiated component as input and return the wrapped component.
A use case of this can be torch.compile(). See the Config Guide for more details.
Other fields in the config content are input arguments to the python module.
Expand Down Expand Up @@ -275,7 +286,11 @@ def resolve_args(self):
Utility function used in `instantiate()` to resolve the arguments from current config content.
"""
return {k: v for k, v in self.get_config().items() if k not in self.non_arg_keys}
return {
k: v
for k, v in self.get_config().items()
if (k not in self.non_arg_keys) or (k == CONFIG_COMPONENT_KEY_WRAPPER and not _wrapper_feature_flag.enabled)
}

def is_disabled(self) -> bool:
"""
Expand All @@ -285,7 +300,7 @@ def is_disabled(self) -> bool:
_is_disabled = self.get_config().get(CONFIG_COMPONENT_KEY_DISABLED, False)
return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled)

def get_wrapper(self) -> None | Callable[[object], object]:
def _get_wrapper(self) -> None | Callable[[object], object]:
"""
Utility function used in `instantiate()` to check whether to skip the instantiation.
Expand All @@ -298,7 +313,7 @@ def instantiate(self, **kwargs: Any) -> object:
The target component must be a `class` or a `function`, otherwise, return `None`.
Args:
kwargs: args to override / add the config args when instantiation.
kwargs: instantiate_kwargs to override / add the config instantiate_kwargs when instantiation.
"""
if not self.is_instantiable(self.get_config()) or self.is_disabled():
Expand All @@ -307,27 +322,35 @@ def instantiate(self, **kwargs: Any) -> object:

modname = self.resolve_module_name()
mode = self.get_config().get(CONFIG_COMPONENT_KEY_MODE, CompInitMode.DEFAULT)
args = self.resolve_args()
args.update(kwargs)
wrapper = self.get_wrapper()
if wrapper is not None:
if callable(wrapper):
return wrapper(instantiate(modname, mode, **args))
else:
raise ValueError(
f"wrapper must be a callable, but got: {wrapper}, type {type(wrapper)}."
f"make sure all references are resolved before calling instantiate"
)
if self.get_id().endswith(CONFIG_COMPONENT_KEY_WRAPPER):
try:
return instantiate(modname, mode, **args)
except Exception as e:
raise RuntimeError(
f"Failed to instantiate {self}. Make sure you are returning a partial "
f"(you might need to add {CONFIG_COMPONENT_KEY_MODE}:callable, "
f"especially when using specifying a class)."
) from e
return instantiate(modname, mode, **args)
instantiate_kwargs = self.resolve_args()
instantiate_kwargs.update(kwargs)
wrapper = self._get_wrapper()
if _wrapper_feature_flag.enabled:
if wrapper is not None:
if callable(wrapper):
return wrapper(instantiate(modname, mode, **instantiate_kwargs))
else:
raise ValueError(
f"wrapper must be a callable, but got type {type(wrapper)}: {wrapper}."
"make sure all references are resolved before calling instantiate "
"and the wrapper is a callable."
)
if self.get_id().endswith(CONFIG_COMPONENT_KEY_WRAPPER):
try:
return instantiate(modname, mode, **instantiate_kwargs)
except Exception as e:
raise RuntimeError(
f"Failed to instantiate {self}. Make sure you are returning a partial "
f"(you might need to add {CONFIG_COMPONENT_KEY_MODE}:callable, "
"especially when using specifying a class)."
) from e
elif wrapper is not None:
warnings.warn(
f"ConfigComponent: {self.get_id()} has a key {CONFIG_COMPONENT_KEY_WRAPPER}. "
"Since the feature flag CONFIG_WRAPPER is not enabled, the key will be treated as a normal config key. "
"In future versions of MONAI, this key might be reserved for the wrapper functionality."
)
return instantiate(modname, mode, **instantiate_kwargs)


class ConfigExpression(ConfigItem):
Expand Down
61 changes: 61 additions & 0 deletions monai/utils/feature_flag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import os
from contextlib import contextmanager

FEATURE_FLAG_PREFIX = "MONAI_FEATURE_ENABLED_"


class FeatureFlag:
def __init__(self, name: str, *, default: bool = False):
self.name = name
self._enabled = None
self.default = default

def _get_from_env(self):
return os.getenv(FEATURE_FLAG_PREFIX + self.name, None)

@property
def enabled(self):
if self._enabled is None:
env = self._get_from_env()
if env is None:
self._enabled = self.default
else:
self._enabled = env.lower() in ["true", "1", "yes"]
return self._enabled

@enabled.setter
def enabled(self, value: bool):
self._enabled = value

def enable(self):
self.enabled = True

def disable(self):
self.enabled = False

def __str__(self):
return f"{self.name}: {self.enabled}, default: {self.default}"


@contextmanager
def with_feature_flag(feature_flag: FeatureFlag, enabled: bool):
original = feature_flag.enabled
feature_flag.enabled = enabled
try:
yield
finally:
feature_flag.enabled = original
47 changes: 32 additions & 15 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
import tempfile
import unittest
import warnings
from collections import OrderedDict
from pathlib import Path
from unittest import mock, skipUnless

import numpy as np
from parameterized import parameterized

from monai.bundle import ConfigParser, ReferenceResolver
from monai.bundle.config_item import ConfigItem
from monai.bundle.config_item import CONFIG_COMPONENT_KEY_WRAPPER, ConfigItem, _wrapper_feature_flag
from monai.data import CacheDataset, DataLoader, Dataset
from monai.transforms import Compose, LoadImaged, RandTorchVisiond
from monai.utils import min_version, optional_import
from monai.utils.feature_flag import with_feature_flag
from tests.utils import TimedCall

_, has_tv = optional_import("torchvision", "0.8.0", min_version)
Expand Down Expand Up @@ -124,16 +126,29 @@ def __call__(self, a, b):
1,
[0, 4],
]
TEST_CASE_WRAPPER = [
TEST_CASE_WRAPPER_ENABLED = [
{
"dataset": {
"_target_": "Dataset",
"data": [1, 2],
"_wrapper_": {"_target_": "CacheDataset", "_mode_": "callable"},
CONFIG_COMPONENT_KEY_WRAPPER: {"_target_": "CacheDataset", "_mode_": "callable"},
}
},
["dataset"],
[CacheDataset],
["dataset", f"dataset#{CONFIG_COMPONENT_KEY_WRAPPER}"],
[CacheDataset, type(CacheDataset)],
True,
]
TEST_CASE_WRAPPER_DISABLED = [
{
"dataset": {
"_target_": "collections.OrderedDict",
"data": [1, 2],
CONFIG_COMPONENT_KEY_WRAPPER: {"_target_": "CacheDataset", "_mode_": "callable"},
}
},
["dataset", f"dataset#{CONFIG_COMPONENT_KEY_WRAPPER}"],
[OrderedDict, type(CacheDataset)],
False,
]


Expand Down Expand Up @@ -368,16 +383,18 @@ def test_parse_json_warn(self, config_string, extension, expected_unique_val, ex
self.assertEqual(parser.get_parsed_content("key#unique"), expected_unique_val)
self.assertIn(parser.get_parsed_content("key#duplicate"), expected_duplicate_vals)

@parameterized.expand([TEST_CASE_WRAPPER])
def test_parse_wrapper(self, config, expected_ids, output_types):
parser = ConfigParser(config=config, globals={"monai": "monai", "torch": "torch"})

for id, cls in zip(expected_ids, output_types):
self.assertTrue(isinstance(parser.get_parsed_content(id), cls))
# test root content
root = parser.get_parsed_content(id="")
for v, cls in zip(root.values(), output_types):
self.assertTrue(isinstance(v, cls))
@parameterized.expand([TEST_CASE_WRAPPER_ENABLED, TEST_CASE_WRAPPER_DISABLED])
def test_parse_wrapper(self, config, expected_ids, output_types, enable_feature_flag):
with with_feature_flag(_wrapper_feature_flag, enable_feature_flag):
parser = ConfigParser(config=config, globals={"monai": "monai", "torch": "torch"})
for id, cls in zip(expected_ids, output_types):
self.assertTrue(isinstance(parser.get_parsed_content(id), cls))
# test root content
root = parser.get_parsed_content(id="")
for v, cls in zip(root.values(), output_types):
self.assertTrue(isinstance(v, cls))
if not enable_feature_flag:
assert CONFIG_COMPONENT_KEY_WRAPPER in root["dataset"]


if __name__ == "__main__":
Expand Down

0 comments on commit 16f579f

Please sign in to comment.