Skip to content

Commit

Permalink
Merge pull request #29 from NREL/pp/cyclic_reference_fix
Browse files Browse the repository at this point in the history
Added error checking for self-referential equations
  • Loading branch information
grantbuster committed Feb 25, 2022
2 parents 1d0a4f4 + a4d3145 commit c64bd85
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 8 deletions.
48 changes: 48 additions & 0 deletions NRWAL/handlers/config.py
Expand Up @@ -265,6 +265,10 @@ def __init__(self, config, inputs=None, interp_extrap_power=False,
self._eqn_dir = EquationDirectory(eqn_dir, **kwargs)
self._global_variables = self._parse_global_variables(config)
self._raw_config = copy.deepcopy(config)

for name, expression in config.items():
self._check_circ_ref(name, expression, config)

self._config = self._parse_config(config, self._eqn_dir,
self._global_variables)

Expand Down Expand Up @@ -382,6 +386,50 @@ def _parse_global_variables(config):

return gvars

@classmethod
def _check_circ_ref(cls, orig_name, expression, config, current_name=None,
msg=None):
"""Check the config for circular variable references that would result
in a recursion error.
Parameters
----------
orig_name : str
The starting equation name to check for circular references.
expression : str
A string entry in the config, can be a number, an EquationDirectory
retrieval string, a key referencing a config entry, or a
mathematical expression combining these options.
config : dict
NRWAL config dictionary mapping names (str) to expressions (str)
current_name : str
The current equation name in the recursive search.
msg : str
The error message to be printed if a circular reference is found.
"""

if current_name is None:
current_name = orig_name

if msg is None:
msg = ('Found a circular reference with NRWAL equations: {}'
.format(orig_name))
else:
msg += ' -> {}'.format(current_name)

all_vars = Equation.parse_variables(expression)

if orig_name in all_vars:
msg += (', and ending with expression "{}": {}'
.format(current_name, expression))
logger.error(msg)
raise RuntimeError(msg)

for var in all_vars:
if var in config:
cls._check_circ_ref(orig_name, config[var], config,
current_name=var, msg=msg)

@classmethod
def _parse_config(cls, config, eqn_dir, gvars):
"""Parse a config mapping of names-to-string-expressions into a
Expand Down
43 changes: 43 additions & 0 deletions NRWAL/handlers/equations.py
Expand Up @@ -49,6 +49,24 @@ def _preflight(self):
logger.error(msg)
raise ValueError(msg)

self.verify_no_self_reference()

def verify_no_self_reference(self):
"""Verify that the equation does not reference itself.
Raises
------
ValueError
If a reference to the equation name is found in its variables.
"""
if self._base_name in self.variables:
msg = ("Self-referencing is not allowed! Please change "
"either the equation name or the name of the dependent "
"variable in the following input equation: {} = {}"
.format(self._base_name, self._eqn))
logger.error(msg)
raise ValueError(msg)

@staticmethod
def _check_input_args(kwargs):
"""Check that input args to equation are of expected types."""
Expand All @@ -70,6 +88,31 @@ def _check_input_args(kwargs):

return kwargs

def replace_equation(self, new_eqn):
"""Replace the expression of this equation with a new one.
This method returns a new `Equation` instance that replaces
the existing equation expression with the new one supplied by the
user, keeping the equation name and default variables unchanged.
Parameters
----------
new_eqn : str
String representation of the new `Equation` instance.
Returns
-------
`Equation`
A new `Equation` instance with the same name and
default values as the old `Equation` but with the new
equation expression.
"""

return self.__class__(
new_eqn, name=self._base_name,
default_variables=self.default_variables
)

def __eqn_math(self, other, operator):
"""Perform arithmetic with this instance of Equation (self) and an
input "other" Equation and return a new Equation object that evaluates
Expand Down
9 changes: 6 additions & 3 deletions NRWAL/handlers/groups.py
Expand Up @@ -747,10 +747,13 @@ def _parse_group(self, group):
working = True
while working:
working = False
for eqn in [v for v in group.values() if isinstance(v, Equation)]:
for group_key, eqn in group.items():
if not isinstance(eqn, Equation):
continue
for var in [v for v in eqn.variables if v in group]:
repl_str = '({})'.format(group[var]._eqn)
eqn._eqn = eqn._eqn.replace(var, repl_str)
repl_str = '({})'.format(group[var].full)
new_eqn = eqn.full.replace(var, repl_str)
group[group_key] = eqn = eqn.replace_equation(new_eqn)
working = True

return group
Expand Down
2 changes: 1 addition & 1 deletion NRWAL/version.py
@@ -1,3 +1,3 @@
"""NRWAL Version number"""

__version__ = "0.0.7"
__version__ = "0.0.8"
14 changes: 14 additions & 0 deletions tests/data/bad_eqn_dir/bad_deep_self_ref_eqn.yaml
@@ -0,0 +1,14 @@
a_constant:
10

other_input:
input_x

other_input_2:
other_input

renamed_input_x:
a_constant * other_input_2

input_x:
renamed_input_x * 1000
2 changes: 2 additions & 0 deletions tests/data/bad_eqn_dir/bad_self_ref_eqn.yaml
@@ -0,0 +1,2 @@
bad_eqn:
bad_eqn * 1000
6 changes: 6 additions & 0 deletions tests/data/test_configs/test_config_bad_3.yaml
@@ -0,0 +1,6 @@
gcf:
cf_mean

# circular ref
cf_mean:
gcf * 1000
9 changes: 9 additions & 0 deletions tests/data/test_configs/test_config_bad_4.yaml
@@ -0,0 +1,9 @@
gcf:
cf_mean

# deeper circular ref
cf_mean:
something_else

something_else:
gcf * 1000
12 changes: 12 additions & 0 deletions tests/test_config.py → tests/test_handlers_config.py
Expand Up @@ -18,6 +18,9 @@
FP_BAD_0 = os.path.join(TEST_DATA_DIR, 'test_configs/test_config_bad_0.yml')
FP_BAD_1 = os.path.join(TEST_DATA_DIR, 'test_configs/test_config_bad_1.yml')
FP_BAD_2 = os.path.join(TEST_DATA_DIR, 'test_configs/test_config_bad_2.yaml')
FP_BAD_3 = os.path.join(TEST_DATA_DIR, 'test_configs/test_config_bad_3.yaml')
FP_BAD_4 = os.path.join(TEST_DATA_DIR, 'test_configs/test_config_bad_4.yaml')

FP_GOOD_0 = os.path.join(TEST_DATA_DIR, 'test_configs/test_config_good_0.yml')
FP_GOOD_1 = os.path.join(TEST_DATA_DIR, 'test_configs/test_config_good_1.yml')
FP_GOOD_2 = os.path.join(TEST_DATA_DIR, 'test_configs/test_config_good_2.yml')
Expand Down Expand Up @@ -266,6 +269,15 @@ def test_config_reference_bad():
_ = NrwalConfig(FP_BAD_2)


def test_bad_circular():
"""Test that NRWAL raises an error for circular equation references"""
with pytest.raises(RuntimeError):
NrwalConfig(FP_BAD_3)

with pytest.raises(RuntimeError):
NrwalConfig(FP_BAD_4)


def test_leading_negative():
"""Test config with equations that have leading negative signs or
negative signs attached to variables"""
Expand Down
20 changes: 16 additions & 4 deletions tests/test_handlers_equations.py
Expand Up @@ -33,16 +33,16 @@ def test_print_eqn():
known_vars = ('depth', 'outfitting_cost')
eqn = obj[eqn_name]
assert len(eqn.variables) == len(known_vars)
assert all([v in eqn.variables for v in known_vars])
assert all([v in str(eqn) for v in known_vars])
assert all(v in eqn.variables for v in known_vars)
assert all(v in str(eqn) for v in known_vars)
assert eqn_name in str(eqn)

eqn_name = 'lattice'
known_vars = ('turbine_capacity', 'depth', 'lattice_cost')
eqn = obj[eqn_name]
assert len(eqn.variables) == len(known_vars)
assert all([v in eqn.variables for v in known_vars])
assert all([v in str(eqn) for v in known_vars])
assert all(v in eqn.variables for v in known_vars)
assert all(v in str(eqn) for v in known_vars)
assert eqn_name in str(eqn)

eqn = obj['subgroup::eqn1']
Expand Down Expand Up @@ -167,3 +167,15 @@ def test_numpy_eqns():
fp *= 2
truth = 2 + 10 * 2 * np.array([0.25, 0.5, 1])
assert np.allclose(eqn.eval(x=x, xp=xp, fp=fp), truth)


@pytest.mark.parametrize(
'bad_fp', ('bad_self_ref_eqn.yaml', 'bad_deep_self_ref_eqn.yaml')
)
def test_self_referential_eq(bad_fp):
"""Test self-referential equations. """
fp = os.path.join(BAD_DIR, bad_fp)
with pytest.raises(ValueError) as excinfo:
EquationGroup(fp)

assert "Self-referencing is not allowed!" in str(excinfo.value)

0 comments on commit c64bd85

Please sign in to comment.