Skip to content

Commit

Permalink
Add config_only_args arg to alf.configurable to prevent local config …
Browse files Browse the repository at this point in the history
…changes (#1656)
  • Loading branch information
QuantuMope committed Jun 3, 2024
1 parent 9f19a72 commit 2824770
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 13 deletions.
74 changes: 61 additions & 13 deletions alf/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def _ensure_wrappability(fn):
return fn


def _make_wrapper(fn, configs, signature, has_self):
def _make_wrapper(fn, configs, signature, has_self, config_only_args):
"""Wrap the function.
Args:
Expand All @@ -527,6 +527,11 @@ def _make_wrapper(fn, configs, signature, has_self):
has_self (bool): whether the first argument is expected to be self but
signature does not contains parameter for self. This should be True
if fn is __init__() function of a class.
config_only_args (list[str]): list of args that should be guarded. In other
words, their values can only be set / changed globally via ``alf.config()``.
This protects against local untracked changes as a result of 1) using
``partial()`` or 2) setting an argument when calling the function, which
can cause unintended side effects.
Returns:
The wrapped function
"""
Expand All @@ -538,12 +543,13 @@ def _wrapper(*args, **kwargs):
num_positional_args = len(args)
num_positional_args -= has_self

set_positional_args = []
for i, (name, param) in enumerate(signature.parameters.items()):
config = configs.get(name, None)
if config is None:
continue
elif i < num_positional_args:
continue
set_positional_args.append(name)
elif param.kind in (Parameter.VAR_POSITIONAL,
Parameter.VAR_KEYWORD):
continue
Expand All @@ -557,13 +563,20 @@ def _wrapper(*args, **kwargs):
unspecified_kw_args[name] = config.get_value()
config.set_used()

for config_only_arg in config_only_args:
if config_only_arg in set_positional_args or config_only_arg in kwargs:
raise ValueError(
f"The arg '{config_only_arg}' of {fn.__qualname__} is guarded but has been modified. "
f"Most likely partial() was used to change this value, which is not allowed."
)

return fn(*args, *unspecified_positional_args, **kwargs,
**unspecified_kw_args)

return _wrapper


def _decorate(fn_or_cls, name, whitelist, blacklist):
def _decorate(fn_or_cls, name, whitelist, blacklist, config_only_args):
"""decorate a function or class.
Args:
Expand All @@ -576,6 +589,11 @@ def _decorate(fn_or_cls, name, whitelist, blacklist):
blacklist (list[str]): A blacklisted set of kwargs that should not be
configurable. All other kwargs will be configurable. Only one of
``whitelist` or ``blacklist`` should be specified.
config_only_args (list[str]): list of args that should be guarded. In other
words, their values can only be set / changed globally via ``alf.config()``.
This protects against local untracked changes as a result of 1) using
``partial()`` or 2) setting an argument when calling the function, which
can cause unintended side effects.
Returns:
The decorated function
"""
Expand Down Expand Up @@ -604,12 +622,17 @@ def _decorate(fn_or_cls, name, whitelist, blacklist):
has_self = construction_fn.__name__ != '__new__'
decorated_fn = _make_wrapper(
_ensure_wrappability(construction_fn), configs, signature,
has_self)
has_self, config_only_args)
if construction_fn.__name__ == '__new__':
decorated_fn = staticmethod(decorated_fn)
setattr(fn_or_cls, construction_fn.__name__, decorated_fn)
else:
fn_or_cls = _make_wrapper(fn_or_cls, configs, signature, has_self=0)
fn_or_cls = _make_wrapper(
fn_or_cls,
configs,
signature,
has_self=0,
config_only_args=config_only_args)

if fn_or_cls.__module__ != '<run_path>' and os.environ.get(
'ALF_USE_GIN', "1") == "1":
Expand Down Expand Up @@ -691,7 +714,10 @@ def _wrapper(*args, **kwargs):
return cls


def configurable(fn_or_name=None, whitelist=[], blacklist=[]):
def configurable(fn_or_name=None,
whitelist=[],
blacklist=[],
config_only_args=[]):
"""Decorator to make a function or class configurable.
This decorator registers the decorated function/class as configurable, which
Expand All @@ -703,7 +729,11 @@ def configurable(fn_or_name=None, whitelist=[], blacklist=[]):
If some parameters should not be configurable, they can be specified in
``blacklist``. If only a restricted set of parameters should be configurable,
they can be specified in ``whitelist``.
they can be specified in ``whitelist``. Furthermore, parameters can be
guarded by being specified in ``config_only_args`` so that their values can only be
changed globally via alf.config(). This prevents unintended side effects
that may arise from having inconsistent parameter values caused by local
changes (e.g., partial()).
The decorator can be used without any parameters as follows:
Expand Down Expand Up @@ -772,7 +802,7 @@ def Test(arg):
using gin. The values specified using ``alf.config()`` will override
values specified through gin. Gin wrapper is quite convoluted and can make
debugging more challenging. It can be disabled by setting environment
varialbe ALF_USE_GIN to 0 if you are not using gin.
variable ALF_USE_GIN to 0 if you are not using gin.
Args:
fn_or_name (Callable|str): A name for this configurable, or a function
Expand All @@ -787,13 +817,23 @@ def Test(arg):
``blacklist`` should be specified.
blacklist (list[str]): A blacklisted set of kwargs that should not be
configurable. All other kwargs will be configurable. Only one of
``whitelist`` or ``blacklist`` should be specified.
``whitelist`` or ``blacklist`` should be specified. An entry that is in
``blacklist`` cannot be in ``config_only_args``.
config_only_args (list[str]): list of args that should be guarded. In other
words, their values can only be set / changed globally via ``alf.config()``.
This protects against local untracked changes as a result of 1) using
``partial()`` or 2) setting an argument when calling the function, which
can cause unintended side effects. An entry that is in ``config_only_args``
cannot be in ``blacklist``.
Returns:
decorated function if fn_or_name is Callable.
a decorator if fn is not Callable.
Raises:
ValueError: If a configurable with ``name`` (or the name of `fn_or_cls`)
already exists, or if both a whitelist and blacklist are specified.
ValueError: Can be raised
1) If a configurable with ``name`` (or the name of `fn_or_cls`) already exists
2) If both a whitelist and blacklist are specified.
3) If the same entry is found in both blacklist and config_only_args.
4) If an arg listed in config_only_args is changed without using alf.config().
"""

if callable(fn_or_name):
Expand All @@ -804,14 +844,22 @@ def Test(arg):
if whitelist and blacklist:
raise ValueError("Only one of 'whitelist' and 'blacklist' can be set.")

for entry in blacklist:
if entry in config_only_args:
raise ValueError(
f"Entry '{entry}' found in both blacklist and config_only_args. "
f"An entry can only be in one of these lists.")

if not callable(fn_or_name):

def _decorator(fn_or_cls):
return _decorate(fn_or_cls, name, whitelist, blacklist)
return _decorate(fn_or_cls, name, whitelist, blacklist,
config_only_args)

return _decorator
else:
return _decorate(fn_or_name, name, whitelist, blacklist)
return _decorate(fn_or_name, name, whitelist, blacklist,
config_only_args)


def define_config(name, default_value):
Expand Down
56 changes: 56 additions & 0 deletions alf/config_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from absl import logging
from functools import partial
import os
import pprint
import tempfile
Expand Down Expand Up @@ -173,6 +174,61 @@ def test_load_config(self):
os.path.exists(
os.path.join(temp_dir, "configs", "base", "base_conf.py"))

def test_config_only_args(self):
@alf.configurable(config_only_args=['y'])
def func_test(y=0, z=0):
pass

@alf.configurable(config_only_args=['y'])
class TestClass:
def __init__(self, y=0, z=0):
pass

@alf.configurable(config_only_args=['y'], whitelist=['y'])
class TestClassWhiteList:
def __init__(self, y=0, z=0):
pass

test_callables = [func_test, TestClass, TestClassWhiteList]

for test_callable in test_callables:
test_callable()
test_callable(z=1)
with self.assertRaises(ValueError) as context:
test_callable(y=1)
test_callable_partial1 = partial(test_callable, z=1)
test_callable_partial1()
test_callable_partial2 = partial(test_callable, y=1)
with self.assertRaises(ValueError) as context:
test_callable_partial2()

with self.assertRaises(ValueError) as context:

@alf.configurable(config_only_args=['y'], blacklist=['y'])
class TestClassBlackList:
def __init__(self, y=0, z=0):
pass

@alf.configurable(config_only_args=['x'])
class TestPositionalArgs:
def __init__(self, x, y, z=0):
pass

with self.assertRaises(ValueError) as context:
TestPositionalArgs(0, 0)
test_partial = partial(TestPositionalArgs, x=0)
with self.assertRaises(ValueError) as context:
test_partial(y=0)
test_partial = partial(TestPositionalArgs, y=0)
with self.assertRaises(ValueError) as context:
test_partial(0)
with self.assertRaises(ValueError) as context:
test_partial(x=0)
alf.config("TestPositionalArgs", x=0)
TestPositionalArgs(y=0)
test_partial = partial(TestPositionalArgs, y=0)
test_partial()


if __name__ == '__main__':
alf.test.main()

0 comments on commit 2824770

Please sign in to comment.