diff --git a/CHANGELOG.md b/CHANGELOG.md index 56f832b1..9ee49264 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com). ## Unreleased ### Changed +* (testing) Added `@flagsaver.as_parsed`: this allows saving/restoring flags + using string values as if parsed from the command line and will also reflect + other flag states after command line parsing, e.g. `.present` is set. * (logging) If no log dir is specified `logging.find_log_dir()` now falls back to `tempfile.gettempdir()` instead of `/tmp/`. diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py index 6661b783..fd0e6310 100644 --- a/absl/flags/_flagvalues.py +++ b/absl/flags/_flagvalues.py @@ -792,8 +792,10 @@ def get_value(): continue if flag is not None: + # LINT.IfChange flag.parse(value) flag.using_default_value = False + # LINT.ThenChange(../testing/flagsaver.py:flag_override_parsing) else: unparsed_names_and_args.append((name, arg)) diff --git a/absl/testing/BUILD b/absl/testing/BUILD index d4287926..3173c4b9 100644 --- a/absl/testing/BUILD +++ b/absl/testing/BUILD @@ -212,6 +212,7 @@ py_test( deps = [ ":absltest", ":flagsaver", + ":parameterized", "//absl/flags", ], ) diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py index 774c698c..e96c8c52 100644 --- a/absl/testing/flagsaver.py +++ b/absl/testing/flagsaver.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """Decorator and context manager for saving and restoring flag values. There are many ways to save and restore. Always use the most convenient method @@ -49,6 +50,36 @@ def some_func(): finally: flagsaver.restore_flag_values(saved_flag_values) + # Use the parsing version to emulate users providing the flags. + # Note that all flags must be provided as strings (unparsed). + @flagsaver.as_parsed(some_int_flag='123') + def some_func(): + # Because the flag was parsed it is considered "present". + assert FLAGS.some_int_flag.present + do_stuff() + + # flagsaver.as_parsed() can also be used as a context manager just like + # flagsaver.flagsaver() + with flagsaver.as_parsed(some_int_flag='123'): + do_stuff() + + # The flagsaver.as_parsed() interface also supports FlagHolder objects. + @flagsaver.as_parsed((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, '23')) + def some_func(): + do_stuff() + + # Using as_parsed with a multi_X flag requires a sequence of strings. + @flagsaver.as_parsed(some_multi_int_flag=['123', '456']) + def some_func(): + assert FLAGS.some_multi_int_flag.present + do_stuff() + + # If a flag name includes non-identifier characters it can be specified like + # so: + @flagsaver.as_parsed(**{'i-like-dashes': 'true'}) + def some_func(): + do_stuff() + We save and restore a shallow copy of each Flag object's ``__dict__`` attribute. This preserves all attributes of the flag, such as whether or not it was overridden from its default value. @@ -58,14 +89,16 @@ def some_func(): and then restore flag values, the added flag will be deleted with no errors. """ +import collections import functools import inspect -from typing import overload, Any, Callable, Mapping, Tuple, TypeVar +from typing import overload, Any, Callable, Mapping, Tuple, TypeVar, Type, Sequence, Union from absl import flags FLAGS = flags.FLAGS + # The type of pre/post wrapped functions. _CallableT = TypeVar('_CallableT', bound=Callable) @@ -83,8 +116,86 @@ def flagsaver(func: _CallableT) -> _CallableT: def flagsaver(*args, **kwargs): """The main flagsaver interface. See module doc for usage.""" + return _construct_overrider(_FlagOverrider, *args, **kwargs) + + +@overload +def as_parsed(*args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]], + **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider': + ... + + +@overload +def as_parsed(func: _CallableT) -> _CallableT: + ... + + +def as_parsed(*args, **kwargs): + """Overrides flags by parsing strings, saves flag state similar to flagsaver. + + This function can be used as either a decorator or context manager similar to + flagsaver.flagsaver(). However, where flagsaver.flagsaver() directly sets the + flags to new values, this function will parse the provided arguments as if + they were provided on the command line. Among other things, this will cause + `FLAGS['flag_name'].parsed == True`. + + A note on unparsed input: For many flag types, the unparsed version will be + a single string. However for multi_x (multi_string, multi_integer, multi_enum) + the unparsed version will be a Sequence of strings. + + Args: + *args: Tuples of FlagHolders and their unparsed value. + **kwargs: The keyword args are flag names, and the values are unparsed + values. + + Returns: + _ParsingFlagOverrider that serves as a context manager or decorator. Will + save previous flag state and parse new flags, then on cleanup it will + restore the previous flag state. + """ + return _construct_overrider(_ParsingFlagOverrider, *args, **kwargs) + + +# NOTE: the order of these overload declarations matters. The type checker will +# pick the first match which could be incorrect. +@overload +def _construct_overrider( + flag_overrider_cls: Type['_ParsingFlagOverrider'], + *args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]], + **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider': + ... + + +@overload +def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'], + *args: Tuple[flags.FlagHolder, Any], + **kwargs: Any) -> '_FlagOverrider': + ... + + +@overload +def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'], + func: _CallableT) -> _CallableT: + ... + + +def _construct_overrider(flag_overrider_cls, *args, **kwargs): + """Handles the args/kwargs returning an instance of flag_overrider_cls. + + If flag_overrider_cls is _FlagOverrider then values should be native python + types matching the python types. Otherwise if flag_overrider_cls is + _ParsingFlagOverrider the values should be strings or sequences of strings. + + Args: + flag_overrider_cls: The class that will do the overriding. + *args: Tuples of FlagHolder and the new flag value. + **kwargs: Keword args mapping flag name to new flag value. + + Returns: + A _FlagOverrider to be used as a decorator or context manager. + """ if not args: - return _FlagOverrider(**kwargs) + return flag_overrider_cls(**kwargs) # args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)` if len(args) == 1 and callable(args[0]): if kwargs: @@ -93,7 +204,7 @@ def flagsaver(*args, **kwargs): func = args[0] if inspect.isclass(func): raise TypeError('@flagsaver.flagsaver cannot be applied to a class.') - return _wrap(func, {}) + return _wrap(flag_overrider_cls, func, {}) # args can be a list of (FlagHolder, value) pairs. # In which case they augment any specified kwargs. for arg in args: @@ -105,7 +216,7 @@ def flagsaver(*args, **kwargs): if holder.name in kwargs: raise ValueError('Cannot set --%s multiple times' % holder.name) kwargs[holder.name] = value - return _FlagOverrider(**kwargs) + return flag_overrider_cls(**kwargs) def save_flag_values( @@ -144,13 +255,27 @@ def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]], flag_values[name].__dict__ = saved -def _wrap(func: _CallableT, overrides: Mapping[str, Any]) -> _CallableT: +@overload +def _wrap(flag_overrider_cls: Type['_FlagOverrider'], func: _CallableT, + overrides: Mapping[str, Any]) -> _CallableT: + ... + + +@overload +def _wrap(flag_overrider_cls: Type['_ParsingFlagOverrider'], func: _CallableT, + overrides: Mapping[str, Union[str, Sequence[str]]]) -> _CallableT: + ... + + +def _wrap(flag_overrider_cls, func, overrides): """Creates a wrapper function that saves/restores flag values. Args: + flag_overrider_cls: The class that will be used as a context manager. func: This will be called between saving flags and restoring flags. overrides: Flag names mapped to their values. These flags will be set after - saving the original flag state. + saving the original flag state. The type of the values depends on if + _FlagOverrider or _ParsingFlagOverrider was specified. Returns: A wrapped version of func. @@ -159,7 +284,7 @@ def _wrap(func: _CallableT, overrides: Mapping[str, Any]) -> _CallableT: @functools.wraps(func) def _flagsaver_wrapper(*args, **kwargs): """Wrapper function that saves and restores flags.""" - with _FlagOverrider(**overrides): + with flag_overrider_cls(**overrides): return func(*args, **kwargs) return _flagsaver_wrapper @@ -179,7 +304,7 @@ def __init__(self, **overrides: Any): def __call__(self, func: _CallableT) -> _CallableT: if inspect.isclass(func): raise TypeError('flagsaver cannot be applied to a class.') - return _wrap(func, self._overrides) + return _wrap(self.__class__, func, self._overrides) def __enter__(self): self._saved_flag_values = save_flag_values(FLAGS) @@ -194,6 +319,55 @@ def __exit__(self, exc_type, exc_value, traceback): restore_flag_values(self._saved_flag_values, FLAGS) +class _ParsingFlagOverrider(_FlagOverrider): + """Context manager for overriding flags. + + Simulates command line parsing. + + This is simlar to _FlagOverrider except that all **overrides should be + strings or sequences of strings, and when context is entered this class calls + .parse(value) + + This results in the flags having .present set properly. + """ + + def __init__(self, **overrides: Union[str, Sequence[str]]): + for flag_name, new_value in overrides.items(): + if isinstance(new_value, str): + continue + if (isinstance(new_value, collections.abc.Sequence) and + all(isinstance(single_value, str) for single_value in new_value)): + continue + raise TypeError( + f'flagsaver.as_parsed() cannot parse {flag_name}. Expected a single ' + f'string or sequence of strings but {type(new_value)} was provided.') + super().__init__(**overrides) + + def __enter__(self): + self._saved_flag_values = save_flag_values(FLAGS) + try: + for flag_name, unparsed_value in self._overrides.items(): + # LINT.IfChange(flag_override_parsing) + FLAGS[flag_name].parse(unparsed_value) + FLAGS[flag_name].using_default_value = False + # LINT.ThenChange() + + # Perform the validation on all modified flags. This is something that + # FLAGS._set_attributes() does for you in _FlagOverrider. + for flag_name in self._overrides: + FLAGS._assert_validators(FLAGS[flag_name].validators) + + except KeyError as e: + # If a flag doesn't exist, an UnrecognizedFlagError is more specific. + restore_flag_values(self._saved_flag_values, FLAGS) + raise flags.UnrecognizedFlagError('Unknown command line flag.') from e + + except: + # It may fail because of flag validators or general parsing issues. + restore_flag_values(self._saved_flag_values, FLAGS) + raise + + def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]: """Returns a copy of the flag object's ``__dict__``. diff --git a/absl/testing/tests/flagsaver_test.py b/absl/testing/tests/flagsaver_test.py index e98cd06f..b8f91a57 100644 --- a/absl/testing/tests/flagsaver_test.py +++ b/absl/testing/tests/flagsaver_test.py @@ -16,6 +16,7 @@ from absl import flags from absl.testing import absltest from absl.testing import flagsaver +from absl.testing import parameterized flags.DEFINE_string('flagsaver_test_flag0', 'unchanged0', 'flag to test with') flags.DEFINE_string('flagsaver_test_flag1', 'unchanged1', 'flag to test with') @@ -31,6 +32,9 @@ STR_FLAG = flags.DEFINE_string( 'flagsaver_test_str_flag', default='str default', help='help') +MULTI_INT_FLAG = flags.DEFINE_multi_integer('flagsaver_test_multi_int_flag', + None, 'flag to test with') + @flags.multi_flags_validator( ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2')) @@ -51,194 +55,83 @@ class _TestError(Exception): """Exception class for use in these tests.""" -class FlagSaverTest(absltest.TestCase): +class CommonUsageTest(absltest.TestCase): + """These test cases cover the most common usages of flagsaver.""" - def test_context_manager_without_parameters(self): - with flagsaver.flagsaver(): - FLAGS.flagsaver_test_flag0 = 'new value' - self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - - def test_context_manager_with_overrides(self): - with flagsaver.flagsaver(flagsaver_test_flag0='new value'): - self.assertEqual('new value', FLAGS.flagsaver_test_flag0) - FLAGS.flagsaver_test_flag1 = 'another value' + def test_as_parsed_context_manager(self): + # Precondition check, we expect all the flags to start as their default. + self.assertEqual('str default', STR_FLAG.value) + self.assertFalse(STR_FLAG.present) + self.assertEqual(1, INT_FLAG.value) self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) - def test_context_manager_with_flagholders(self): - with flagsaver.flagsaver((INT_FLAG, 3), (STR_FLAG, 'new value')): - self.assertEqual('new value', STR_FLAG.value) - self.assertEqual(3, INT_FLAG.value) - FLAGS.flagsaver_test_flag1 = 'another value' - self.assertEqual(INT_FLAG.value, INT_FLAG.default) - self.assertEqual(STR_FLAG.value, STR_FLAG.default) - self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) - - def test_context_manager_with_overrides_and_flagholders(self): - with flagsaver.flagsaver((INT_FLAG, 3), flagsaver_test_flag0='new value'): - self.assertEqual(STR_FLAG.default, STR_FLAG.value) - self.assertEqual(3, INT_FLAG.value) - FLAGS.flagsaver_test_flag0 = 'new value' - self.assertEqual(INT_FLAG.value, INT_FLAG.default) - self.assertEqual(STR_FLAG.value, STR_FLAG.default) - self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - - def test_context_manager_with_cross_validated_overrides_set_together(self): - # When the flags are set in the same flagsaver call their validators will - # be triggered only once the setting is done. - with flagsaver.flagsaver( - flagsaver_test_validated_flag1='new_value', - flagsaver_test_validated_flag2='new_value'): - self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1) - self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2) - - self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) - self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) - - def test_context_manager_with_cross_validated_overrides_set_badly(self): - - # Different values should violate the validator. - with self.assertRaisesRegex(flags.IllegalFlagValueError, - 'Flag validation failed'): - with flagsaver.flagsaver( - flagsaver_test_validated_flag1='new_value', - flagsaver_test_validated_flag2='other_value'): - pass - - self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) - self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) - - def test_context_manager_with_cross_validated_overrides_set_separately(self): - - # Setting just one flag will trip the validator as well. - with self.assertRaisesRegex(flags.IllegalFlagValueError, - 'Flag validation failed'): - with flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value'): - pass - - self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) - self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) - - def test_context_manager_with_exception(self): - with self.assertRaises(_TestError): - with flagsaver.flagsaver(flagsaver_test_flag0='new value'): - self.assertEqual('new value', FLAGS.flagsaver_test_flag0) - FLAGS.flagsaver_test_flag1 = 'another value' - raise _TestError('oops') - self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) - - def test_context_manager_with_validation_exception(self): - with self.assertRaises(flags.IllegalFlagValueError): - with flagsaver.flagsaver( - flagsaver_test_flag0='new value', - flagsaver_test_validated_flag='new value'): - pass - self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) - self.assertIsNone(FLAGS.flagsaver_test_validated_flag) - - def test_decorator_without_call(self): - - @flagsaver.flagsaver - def mutate_flags(value): - """Test function that mutates a flag.""" - # The undecorated method mutates --flagsaver_test_flag0 to the given value - # and then returns the value of that flag. If the @flagsaver.flagsaver - # decorator works as designed, then this mutation will be reverted after - # this method returns. - FLAGS.flagsaver_test_flag0 = value - return FLAGS.flagsaver_test_flag0 - - # mutate_flags returns the flag value before it gets restored by - # the flagsaver decorator. So we check that flag value was - # actually changed in the method's scope. - self.assertEqual('new value', mutate_flags('new value')) - # But... notice that the flag is now unchanged0. - self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - - def test_decorator_without_parameters(self): - - @flagsaver.flagsaver() - def mutate_flags(value): - FLAGS.flagsaver_test_flag0 = value - return FLAGS.flagsaver_test_flag0 - - self.assertEqual('new value', mutate_flags('new value')) - self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - - def test_decorator_with_overrides(self): + # Flagsaver will also save the state of flags that have been modified. + FLAGS.flagsaver_test_flag1 = 'outside flagsaver' + + # Save all existing flag state, and set some flags as if they were parsed on + # the command line. Because of this, the new values must be provided as str, + # even if the flag type is something other than string. + with flagsaver.as_parsed( + (STR_FLAG, 'new string value'), # Override using flagholder object. + (INT_FLAG, '123'), # Override an int flag (NOTE: must specify as str). + flagsaver_test_flag0='new value', # Override using flag name. + ): + # All the flags have their overridden values. + self.assertEqual('new string value', STR_FLAG.value) + self.assertTrue(STR_FLAG.present) + self.assertEqual(123, INT_FLAG.value) + self.assertEqual('new value', FLAGS.flagsaver_test_flag0) + # Even if we change other flags, they will reset on context exit. + FLAGS.flagsaver_test_flag1 = 'new value 1' - @flagsaver.flagsaver(flagsaver_test_flag0='new value') - def mutate_flags(): - """Test function expecting new value.""" - # If the @flagsaver.decorator decorator works as designed, - # then the value of the flag should be changed in the scope of - # the method but the change will be reverted after this method - # returns. - return FLAGS.flagsaver_test_flag0 - - # mutate_flags returns the flag value before it gets restored by - # the flagsaver decorator. So we check that flag value was - # actually changed in the method's scope. - self.assertEqual('new value', mutate_flags()) - # But... notice that the flag is now unchanged0. + # The flags have all reset to their pre-flagsaver values. + self.assertEqual('str default', STR_FLAG.value) + self.assertFalse(STR_FLAG.present) + self.assertEqual(1, INT_FLAG.value) self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - - def test_decorator_with_cross_validated_overrides_set_together(self): - - # When the flags are set in the same flagsaver call their validators will - # be triggered only once the setting is done. - @flagsaver.flagsaver( - flagsaver_test_validated_flag1='new_value', - flagsaver_test_validated_flag2='new_value') - def mutate_flags_together(): - return (FLAGS.flagsaver_test_validated_flag1, - FLAGS.flagsaver_test_validated_flag2) - - self.assertEqual(('new_value', 'new_value'), mutate_flags_together()) - - # The flags have not changed outside the context of the function. - self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) - self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) - - def test_decorator_with_cross_validated_overrides_set_badly(self): - - # Different values should violate the validator. - @flagsaver.flagsaver( - flagsaver_test_validated_flag1='new_value', - flagsaver_test_validated_flag2='other_value') - def mutate_flags_together_badly(): - return (FLAGS.flagsaver_test_validated_flag1, - FLAGS.flagsaver_test_validated_flag2) - - with self.assertRaisesRegex(flags.IllegalFlagValueError, - 'Flag validation failed'): - mutate_flags_together_badly() - - # The flags have not changed outside the context of the exception. - self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) - self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) - - def test_decorator_with_cross_validated_overrides_set_separately(self): - - # Setting the flags sequentially and not together will trip the validator, - # because it will be called at the end of each flagsaver call. - @flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value') - @flagsaver.flagsaver(flagsaver_test_validated_flag2='new_value') - def mutate_flags_separately(): - return (FLAGS.flagsaver_test_validated_flag1, - FLAGS.flagsaver_test_validated_flag2) - - with self.assertRaisesRegex(flags.IllegalFlagValueError, - 'Flag validation failed'): - mutate_flags_separately() - - # The flags have not changed outside the context of the exception. - self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) - self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) - - def test_save_flag_value(self): + self.assertEqual('outside flagsaver', FLAGS.flagsaver_test_flag1) + + def test_as_parsed_decorator(self): + # flagsaver.as_parsed can also be used as a decorator. + @flagsaver.as_parsed((INT_FLAG, '123')) + def do_something_with_flags(): + self.assertEqual(123, INT_FLAG.value) + self.assertTrue(INT_FLAG.present) + + do_something_with_flags() + self.assertEqual(1, INT_FLAG.value) + self.assertFalse(INT_FLAG.present) + + def test_flagsaver_flagsaver(self): + # If you don't want the flags to go through parsing, you can instead use + # flagsaver.flagsaver(). With this method, you provide the native python + # value you'd like the flags to take on. Otherwise it functions similar to + # flagsaver.as_parsed(). + @flagsaver.flagsaver((INT_FLAG, 345)) + def do_something_with_flags(): + self.assertEqual(345, INT_FLAG.value) + # Note that because this flag was never parsed, it will not register as + # .present unless you manually set that attribute. + self.assertFalse(INT_FLAG.present) + # If you do chose to modify things about the flag (such as .present) those + # changes will still be cleaned up when flagsaver.flagsaver() exits. + INT_FLAG.present = True + + self.assertEqual(1, INT_FLAG.value) + # flagsaver.flagsaver() restored INT_FLAG.present to the state it was in + # before entering the context. + self.assertFalse(INT_FLAG.present) + + +class SaveFlagValuesTest(absltest.TestCase): + """Test flagsaver.save_flag_values() and flagsaver.restore_flag_values(). + + In this test, we insure that *all* properties of flags get restored. In other + tests we only try changing the flag value. + """ + + def test_assign_value(self): # First save the flag values. saved_flag_values = flagsaver.save_flag_values() @@ -250,7 +143,7 @@ def test_save_flag_value(self): flagsaver.restore_flag_values(saved_flag_values) self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - def test_save_flag_default(self): + def test_set_default(self): # First save the flag. saved_flag_values = flagsaver.save_flag_values() @@ -262,7 +155,7 @@ def test_save_flag_default(self): flagsaver.restore_flag_values(saved_flag_values) self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].default) - def test_restore_after_parse(self): + def test_parse(self): # First save the flag. saved_flag_values = flagsaver.save_flag_values() @@ -278,9 +171,72 @@ def test_restore_after_parse(self): self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].value) self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present) - def test_decorator_with_exception(self): + def test_assign_validators(self): + # First save the flag. + saved_flag_values = flagsaver.save_flag_values() + + # Sanity check that a validator already exists. + self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 1) + original_validators = list(FLAGS['flagsaver_test_flag0'].validators) + + def no_space(value): + return ' ' not in value + + # Add a new validator. + flags.register_validator('flagsaver_test_flag0', no_space) + self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2) + + # Now restore the flag to its original value. + flagsaver.restore_flag_values(saved_flag_values) + self.assertEqual( + original_validators, FLAGS['flagsaver_test_flag0'].validators + ) + + +@parameterized.named_parameters( + dict( + testcase_name='flagsaver.flagsaver', + flagsaver_method=flagsaver.flagsaver, + ), + dict( + testcase_name='flagsaver.as_parsed', + flagsaver_method=flagsaver.as_parsed, + ), +) +class NoOverridesTest(parameterized.TestCase): + """Test flagsaver.flagsaver and flagsaver.as_parsed without overrides.""" + + def test_context_manager_with_call(self, flagsaver_method): + with flagsaver_method(): + FLAGS.flagsaver_test_flag0 = 'new value' + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + + def test_context_manager_with_exception(self, flagsaver_method): + with self.assertRaises(_TestError): + with flagsaver_method(): + FLAGS.flagsaver_test_flag0 = 'new value' + # Simulate a failed test. + raise _TestError('something happened') + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + + def test_decorator_without_call(self, flagsaver_method): + @flagsaver_method + def mutate_flags(): + FLAGS.flagsaver_test_flag0 = 'new value' + + mutate_flags() + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + + def test_decorator_with_call(self, flagsaver_method): + @flagsaver_method() + def mutate_flags(): + FLAGS.flagsaver_test_flag0 = 'new value' - @flagsaver.flagsaver + mutate_flags() + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + + def test_decorator_with_exception(self, flagsaver_method): + @flagsaver_method() def raise_exception(): FLAGS.flagsaver_test_flag0 = 'new value' # Simulate a failed test. @@ -290,62 +246,262 @@ def raise_exception(): self.assertRaises(_TestError, raise_exception) self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - def test_validator_list_is_restored(self): - self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 1) - original_validators = list(FLAGS['flagsaver_test_flag0'].validators) +@parameterized.named_parameters( + dict( + testcase_name='flagsaver.flagsaver', + flagsaver_method=flagsaver.flagsaver, + ), + dict( + testcase_name='flagsaver.as_parsed', + flagsaver_method=flagsaver.as_parsed, + ), +) +class TestStringFlagOverrides(parameterized.TestCase): + """Test flagsaver.flagsaver and flagsaver.as_parsed with string overrides. + + Note that these tests can be parameterized because both .flagsaver and + .as_parsed expect a str input when overriding a string flag. For non-string + flags these two flagsaver methods have separate tests elsewhere in this file. + + Each test is one class of overrides, executed twice. Once as a context + manager, and once as a decorator on a mutate_flags() method. + """ + + def test_keyword_overrides(self, flagsaver_method): + # Context manager: + with flagsaver_method(flagsaver_test_flag0='new value'): + self.assertEqual('new value', FLAGS.flagsaver_test_flag0) + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - @flagsaver.flagsaver - def modify_validators(): + # Decorator: + @flagsaver_method(flagsaver_test_flag0='new value') + def mutate_flags(): + self.assertEqual('new value', FLAGS.flagsaver_test_flag0) - def no_space(value): - return ' ' not in value + mutate_flags() + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - flags.register_validator('flagsaver_test_flag0', no_space) - self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2) + def test_flagholder_overrides(self, flagsaver_method): + with flagsaver_method((STR_FLAG, 'new value')): + self.assertEqual('new value', STR_FLAG.value) + self.assertEqual('str default', STR_FLAG.value) - modify_validators() - self.assertEqual(original_validators, - FLAGS['flagsaver_test_flag0'].validators) + @flagsaver_method((STR_FLAG, 'new value')) + def mutate_flags(): + self.assertEqual('new value', STR_FLAG.value) + mutate_flags() + self.assertEqual('str default', STR_FLAG.value) -class FlagSaverDecoratorUsageTest(absltest.TestCase): + def test_keyword_and_flagholder_overrides(self, flagsaver_method): + with flagsaver_method( + (STR_FLAG, 'another value'), flagsaver_test_flag0='new value' + ): + self.assertEqual('another value', STR_FLAG.value) + self.assertEqual('new value', FLAGS.flagsaver_test_flag0) + self.assertEqual('str default', STR_FLAG.value) + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - @flagsaver.flagsaver - def test_mutate1(self): - # Even though other test cases change the flag, it should be - # restored to 'unchanged0' if the flagsaver is working. + @flagsaver_method( + (STR_FLAG, 'another value'), flagsaver_test_flag0='new value' + ) + def mutate_flags(): + self.assertEqual('another value', STR_FLAG.value) + self.assertEqual('new value', FLAGS.flagsaver_test_flag0) + + mutate_flags() + self.assertEqual('str default', STR_FLAG.value) self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - FLAGS.flagsaver_test_flag0 = 'changed0' - @flagsaver.flagsaver - def test_mutate2(self): - # Even though other test cases change the flag, it should be - # restored to 'unchanged0' if the flagsaver is working. + def test_cross_validated_overrides_set_together(self, flagsaver_method): + # When the flags are set in the same flagsaver call their validators will + # be triggered only once the setting is done. + with flagsaver_method( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='new_value', + ): + self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1) + self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + @flagsaver_method( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='new_value', + ) + def mutate_flags(): + self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1) + self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2) + + mutate_flags() + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_cross_validated_overrides_set_badly(self, flagsaver_method): + # Different values should violate the validator. + with self.assertRaisesRegex( + flags.IllegalFlagValueError, 'Flag validation failed' + ): + with flagsaver_method( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='other_value', + ): + pass + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + @flagsaver_method( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='other_value', + ) + def mutate_flags(): + pass + + self.assertRaisesRegex( + flags.IllegalFlagValueError, 'Flag validation failed', mutate_flags + ) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_cross_validated_overrides_set_separately(self, flagsaver_method): + # Setting just one flag will trip the validator as well. + with self.assertRaisesRegex( + flags.IllegalFlagValueError, 'Flag validation failed' + ): + with flagsaver_method(flagsaver_test_validated_flag1='new_value'): + pass + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + @flagsaver_method(flagsaver_test_validated_flag1='new_value') + def mutate_flags(): + pass + + self.assertRaisesRegex( + flags.IllegalFlagValueError, 'Flag validation failed', mutate_flags + ) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_validation_exception(self, flagsaver_method): + with self.assertRaises(flags.IllegalFlagValueError): + with flagsaver_method( + flagsaver_test_flag0='new value', + flagsaver_test_validated_flag='new value', + ): + pass self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - FLAGS.flagsaver_test_flag0 = 'changed0' + self.assertIsNone(FLAGS.flagsaver_test_validated_flag) - @flagsaver.flagsaver - def test_mutate3(self): - # Even though other test cases change the flag, it should be - # restored to 'unchanged0' if the flagsaver is working. + @flagsaver_method( + flagsaver_test_flag0='new value', + flagsaver_test_validated_flag='new value', + ) + def mutate_flags(): + pass + + self.assertRaises(flags.IllegalFlagValueError, mutate_flags) self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - FLAGS.flagsaver_test_flag0 = 'changed0' + self.assertIsNone(FLAGS.flagsaver_test_validated_flag) - @flagsaver.flagsaver - def test_mutate4(self): - # Even though other test cases change the flag, it should be - # restored to 'unchanged0' if the flagsaver is working. + def test_unknown_flag_raises_exception(self, flagsaver_method): + self.assertNotIn('this_flag_does_not_exist', FLAGS) + + # Flagsaver raises an error when trying to override a non-existent flag. + with self.assertRaises(flags.UnrecognizedFlagError): + with flagsaver_method( + flagsaver_test_flag0='new value', this_flag_does_not_exist='new value' + ): + pass self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - FLAGS.flagsaver_test_flag0 = 'changed0' + @flagsaver_method( + flagsaver_test_flag0='new value', this_flag_does_not_exist='new value' + ) + def mutate_flags(): + pass + + self.assertRaises(flags.UnrecognizedFlagError, mutate_flags) + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) -class FlagSaverSetUpTearDownUsageTest(absltest.TestCase): + # Make sure flagsaver didn't create the flag at any point. + self.assertNotIn('this_flag_does_not_exist', FLAGS) + + +class AsParsedTest(absltest.TestCase): + + def test_parse_context_manager_sets_present_and_using_default(self): + self.assertFalse(INT_FLAG.present) + self.assertFalse(STR_FLAG.present) + # Note that .using_default_value isn't available on the FlagHolder directly. + self.assertTrue(FLAGS[INT_FLAG.name].using_default_value) + self.assertTrue(FLAGS[STR_FLAG.name].using_default_value) + + with flagsaver.as_parsed((INT_FLAG, '123'), + flagsaver_test_str_flag='new value'): + self.assertTrue(INT_FLAG.present) + self.assertTrue(STR_FLAG.present) + self.assertFalse(FLAGS[INT_FLAG.name].using_default_value) + self.assertFalse(FLAGS[STR_FLAG.name].using_default_value) + + self.assertFalse(INT_FLAG.present) + self.assertFalse(STR_FLAG.present) + self.assertTrue(FLAGS[INT_FLAG.name].using_default_value) + self.assertTrue(FLAGS[STR_FLAG.name].using_default_value) + + def test_parse_decorator_sets_present_and_using_default(self): + self.assertFalse(INT_FLAG.present) + self.assertFalse(STR_FLAG.present) + # Note that .using_default_value isn't available on the FlagHolder directly. + self.assertTrue(FLAGS[INT_FLAG.name].using_default_value) + self.assertTrue(FLAGS[STR_FLAG.name].using_default_value) + + @flagsaver.as_parsed((INT_FLAG, '123'), flagsaver_test_str_flag='new value') + def some_func(): + self.assertTrue(INT_FLAG.present) + self.assertTrue(STR_FLAG.present) + self.assertFalse(FLAGS[INT_FLAG.name].using_default_value) + self.assertFalse(FLAGS[STR_FLAG.name].using_default_value) + + some_func() + self.assertFalse(INT_FLAG.present) + self.assertFalse(STR_FLAG.present) + self.assertTrue(FLAGS[INT_FLAG.name].using_default_value) + self.assertTrue(FLAGS[STR_FLAG.name].using_default_value) + + def test_parse_decorator_with_multi_int_flag(self): + self.assertFalse(MULTI_INT_FLAG.present) + self.assertIsNone(MULTI_INT_FLAG.value) + + @flagsaver.as_parsed((MULTI_INT_FLAG, ['123', '456'])) + def assert_flags_updated(): + self.assertTrue(MULTI_INT_FLAG.present) + self.assertCountEqual([123, 456], MULTI_INT_FLAG.value) + + assert_flags_updated() + self.assertFalse(MULTI_INT_FLAG.present) + self.assertIsNone(MULTI_INT_FLAG.value) + + def test_parse_raises_type_error(self): + with self.assertRaisesRegex( + TypeError, + r'flagsaver\.as_parsed\(\) cannot parse flagsaver_test_int_flag\. ' + r'Expected a single string or sequence of strings but .*int.* was ' + r'provided\.'): + manager = flagsaver.as_parsed(flagsaver_test_int_flag=123) + del manager + + +class SetUpTearDownTest(absltest.TestCase): + """Example using a single flagsaver in setUp.""" def setUp(self): + super().setUp() self.saved_flag_values = flagsaver.save_flag_values() def tearDown(self): + super().tearDown() flagsaver.restore_flag_values(self.saved_flag_values) def test_mutate1(self): @@ -360,28 +516,26 @@ def test_mutate2(self): self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) FLAGS.flagsaver_test_flag0 = 'changed0' - def test_mutate3(self): - # Even though other test cases change the flag, it should be - # restored to 'unchanged0' if the flagsaver is working. - self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - FLAGS.flagsaver_test_flag0 = 'changed0' - - def test_mutate4(self): - # Even though other test cases change the flag, it should be - # restored to 'unchanged0' if the flagsaver is working. - self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) - FLAGS.flagsaver_test_flag0 = 'changed0' - -class FlagSaverBadUsageTest(absltest.TestCase): - """Tests that certain kinds of improper usages raise errors.""" - - def test_flag_saver_on_class(self): +@parameterized.named_parameters( + dict( + testcase_name='flagsaver.flagsaver', + flagsaver_method=flagsaver.flagsaver, + ), + dict( + testcase_name='flagsaver.as_parsed', + flagsaver_method=flagsaver.as_parsed, + ), +) +class BadUsageTest(parameterized.TestCase): + """Tests that improper usage (such as decorating a class) raise errors.""" + + def test_flag_saver_on_class(self, flagsaver_method): with self.assertRaises(TypeError): # WRONG. Don't do this. # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest. - @flagsaver.flagsaver + @flagsaver_method class FooTest(absltest.TestCase): def test_tautology(self): @@ -389,12 +543,12 @@ def test_tautology(self): del FooTest - def test_flag_saver_call_on_class(self): + def test_flag_saver_call_on_class(self, flagsaver_method): with self.assertRaises(TypeError): # WRONG. Don't do this. # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest. - @flagsaver.flagsaver() + @flagsaver_method() class FooTest(absltest.TestCase): def test_tautology(self): @@ -402,12 +556,12 @@ def test_tautology(self): del FooTest - def test_flag_saver_with_overrides_on_class(self): + def test_flag_saver_with_overrides_on_class(self, flagsaver_method): with self.assertRaises(TypeError): # WRONG. Don't do this. # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest. - @flagsaver.flagsaver(foo='bar') + @flagsaver_method(foo='bar') class FooTest(absltest.TestCase): def test_tautology(self): @@ -415,48 +569,57 @@ def test_tautology(self): del FooTest - def test_multiple_positional_parameters(self): + def test_multiple_positional_parameters(self, flagsaver_method): with self.assertRaises(ValueError): func_a = lambda: None func_b = lambda: None - flagsaver.flagsaver(func_a, func_b) + flagsaver_method(func_a, func_b) - def test_both_positional_and_keyword_parameters(self): + def test_both_positional_and_keyword_parameters(self, flagsaver_method): with self.assertRaises(ValueError): func_a = lambda: None - flagsaver.flagsaver(func_a, flagsaver_test_flag0='new value') + flagsaver_method(func_a, flagsaver_test_flag0='new value') - def test_duplicate_holder_parameters(self): + def test_duplicate_holder_parameters(self, flagsaver_method): with self.assertRaises(ValueError): - flagsaver.flagsaver((INT_FLAG, 45), (INT_FLAG, 45)) + flagsaver_method((INT_FLAG, 45), (INT_FLAG, 45)) - def test_duplicate_holder_and_kw_parameter(self): + def test_duplicate_holder_and_kw_parameter(self, flagsaver_method): with self.assertRaises(ValueError): - flagsaver.flagsaver((INT_FLAG, 45), **{INT_FLAG.name: 45}) + flagsaver_method((INT_FLAG, 45), **{INT_FLAG.name: 45}) - def test_both_positional_and_holder_parameters(self): + def test_both_positional_and_holder_parameters(self, flagsaver_method): with self.assertRaises(ValueError): func_a = lambda: None - flagsaver.flagsaver(func_a, (INT_FLAG, 45)) + flagsaver_method(func_a, (INT_FLAG, 45)) - def test_holder_parameters_wrong_shape(self): + def test_holder_parameters_wrong_shape(self, flagsaver_method): with self.assertRaises(ValueError): - flagsaver.flagsaver(INT_FLAG) + flagsaver_method(INT_FLAG) - def test_holder_parameters_tuple_too_long(self): + def test_holder_parameters_tuple_too_long(self, flagsaver_method): with self.assertRaises(ValueError): # Even if it is a bool flag, it should be a tuple - flagsaver.flagsaver((INT_FLAG, 4, 5)) + flagsaver_method((INT_FLAG, 4, 5)) - def test_holder_parameters_tuple_wrong_type(self): + def test_holder_parameters_tuple_wrong_type(self, flagsaver_method): with self.assertRaises(ValueError): # Even if it is a bool flag, it should be a tuple - flagsaver.flagsaver((4, INT_FLAG)) + flagsaver_method((4, INT_FLAG)) - def test_both_wrong_positional_parameters(self): + def test_both_wrong_positional_parameters(self, flagsaver_method): with self.assertRaises(ValueError): func_a = lambda: None - flagsaver.flagsaver(func_a, STR_FLAG, '45') + flagsaver_method(func_a, STR_FLAG, '45') + + def test_context_manager_no_call(self, flagsaver_method): + # The exact exception that's raised appears to be system specific. + with self.assertRaises((AttributeError, TypeError)): + # Wrong. You must call the flagsaver method before using it as a CM. + with flagsaver_method: + # We don't expect to get here. A type error should happen when + # attempting to enter the context manager. + pass if __name__ == '__main__':