diff --git a/pyomo/repn/linear.py b/pyomo/repn/linear.py index 2c47e118a2c..dd5b44cd58a 100644 --- a/pyomo/repn/linear.py +++ b/pyomo/repn/linear.py @@ -655,15 +655,36 @@ def __init__(self): self[SumExpression] = self._before_general_expression @staticmethod - def _before_var(visitor, child): - _id = id(child) + def record_monomial(visitor, result, coef, var): + _id = id(var) if _id not in visitor.var_map: - if child.fixed: - return False, (_CONSTANT, check_constant(child.value, child, visitor)) - visitor.var_recorder.add(child) + if var.fixed: + var = check_constant(var.value, var, visitor) + if not coef and var.__class__ is InvalidNumber: + deprecation_warning( + f"Encountered {coef}*{val2str(var)} in expression " + "tree. Mapping the NaN result to 0 for compatibility " + "with the lp_v1 writer. In the future, this NaN " + "will be preserved/emitted to comply with IEEE-754.", + version='6.6.0', + ) + else: + result.constant += coef * var + return + visitor.var_recorder.add(var) + if _id in result.linear: + result.linear[_id] += coef + elif coef: + result.linear[_id] = coef + + @staticmethod + def _before_var(visitor, child): ans = visitor.Result() - ans.linear[_id] = 1 - return False, (_LINEAR, ans) + visitor.record_monomial(visitor, ans, 1, child) + if ans.linear: + return False, (_LINEAR, ans) + else: + return False, (_CONSTANT, ans.constant) @staticmethod def _before_monomial(visitor, child): @@ -678,104 +699,45 @@ def _before_monomial(visitor, child): except (ValueError, ArithmeticError): return True, None - # We want to check / update the var_map before processing "0" - # coefficients so that we are consistent with what gets added to the - # var_map (e.g., 0*x*y: y is processed by _before_var and will - # always be added, but x is processed here) - _id = id(arg2) - if _id not in visitor.var_map: - if arg2.fixed: - return False, ( - _CONSTANT, - arg1 * check_constant(arg2.value, arg2, visitor), - ) - visitor.var_recorder.add(arg2) - - # Trap multiplication by 0 and nan. Note that arg1 was reduced - # to a numeric value at the beginning of this method. - if not arg1: - if arg2.fixed: - arg2 = check_constant(arg2.value, arg2, visitor) - if arg2.__class__ is InvalidNumber: - deprecation_warning( - f"Encountered {arg1}*{val2str(arg2)} in expression " - "tree. Mapping the NaN result to 0 for compatibility " - "with the lp_v1 writer. In the future, this NaN " - "will be preserved/emitted to comply with IEEE-754.", - version='6.6.0', - ) - return False, (_CONSTANT, arg1) - ans = visitor.Result() - ans.linear[_id] = arg1 - return False, (_LINEAR, ans) + visitor.record_monomial(visitor, ans, arg1, arg2), + if ans.linear: + return False, (_LINEAR, ans) + else: + return False, (_CONSTANT, ans.constant) @staticmethod def _before_linear(visitor, child): - var_map = visitor.var_map ans = visitor.Result() + recorder = visitor.record_monomial + evaluate = visitor.evaluate const = 0 - linear = ans.linear for arg in child.args: if arg.__class__ is MonomialTermExpression: arg1, arg2 = arg._args_ if arg1.__class__ not in native_types: try: - arg1 = check_constant(visitor.evaluate(arg1), arg1, visitor) + arg1 = check_constant(evaluate(arg1), arg1, visitor) except (ValueError, ArithmeticError): return True, None - # Trap multiplication by 0 and nan. Note that arg1 was - # reduced to a numeric value at the beginning of this - # method. - if not arg1: - if arg2.fixed: - arg2 = check_constant(arg2.value, arg2, visitor) - if arg2.__class__ is InvalidNumber: - deprecation_warning( - f"Encountered {arg1}*{val2str(arg2)} in expression " - "tree. Mapping the NaN result to 0 for compatibility " - "with the lp_v1 writer. In the future, this NaN " - "will be preserved/emitted to comply with IEEE-754.", - version='6.6.0', - ) - continue - - _id = id(arg2) - if _id not in var_map: - if arg2.fixed: - const += arg1 * check_constant(arg2.value, arg2, visitor) - continue - visitor.var_recorder.add(arg2) - linear[_id] = arg1 - elif _id in linear: - linear[_id] += arg1 - else: - linear[_id] = arg1 + recorder(visitor, ans, arg1, arg2) elif arg.__class__ in native_numeric_types: const += arg elif arg.is_variable_type(): - _id = id(arg) - if _id not in var_map: - if arg.fixed: - const += check_constant(arg.value, arg, visitor) - continue - visitor.var_recorder.add(arg) - linear[_id] = 1 - elif _id in linear: - linear[_id] += 1 - else: - linear[_id] = 1 + recorder(visitor, ans, 1, arg) else: + # Fixed objects; e.g. Params & Param expressions try: - const += check_constant(visitor.evaluate(arg), arg, visitor) + const += check_constant(evaluate(arg), arg, visitor) except (ValueError, ArithmeticError): return True, None - if linear: - ans.constant = const + + ans.constant += const + if ans.linear: return False, (_LINEAR, ans) else: - return False, (_CONSTANT, const) + return False, (_CONSTANT, ans.constant) @staticmethod def _before_named_expression(visitor, child): @@ -822,14 +784,15 @@ def __init__( super().__init__() self.subexpression_cache = subexpression_cache if any(_ is not None for _ in (var_map, var_order, sorter)): + _name = self.__class__.__name__ if var_recorder is not None: raise ValueError( - "LinearRepnVisitor: cannot specify any of var_map, " + f"{_name}: cannot specify any of var_map, " "var_order, or sorter with var_recorder" ) deprecation_warning( "var_map, var_order, and sorter are deprecated arguments to " - "LinearRepnVisitor(). Please pass the VarRecorder object directly.", + f"{_name}(). Please pass the VarRecorder object directly.", version='6.8.1', ) var_recorder = OrderedVarRecorder(var_map, var_order, sorter) @@ -839,6 +802,7 @@ def __init__( ) self.var_recorder = var_recorder self.var_map = var_recorder.var_map + self.record_monomial = self.before_child_dispatcher.record_monomial self._eval_expr_visitor = _EvaluationVisitor(True) self.evaluate = self._eval_expr_visitor.dfs_postorder_stack diff --git a/pyomo/repn/tests/test_linear.py b/pyomo/repn/tests/test_linear.py index ebaa3ee8001..269d1fc73c1 100644 --- a/pyomo/repn/tests/test_linear.py +++ b/pyomo/repn/tests/test_linear.py @@ -572,13 +572,18 @@ def test_monomial(self): cfg = VisitorConfig() with LoggingIntercept() as LOG: repn = LinearRepnVisitor(**cfg).walk_expression(param_expr) - self.assertEqual(LOG.getvalue(), "") + self.assertRegex( + LOG.getvalue(), + r"DEPRECATED: Encountered 0\*InvalidNumber\(nan\) in expression tree.", + ) self.assertEqual(cfg.subexpr, {}) self.assertEqual(cfg.var_map, {}) self.assertEqual(cfg.var_order, {}) self.assertEqual(repn.multiplier, 1) - self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan)) + # This will be true after we remove the deprecation warning: + # self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan)) + self.assertEqual(repn.constant, 0) self.assertEqual(repn.linear, {}) self.assertEqual(repn.nonlinear, None)