Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 48 additions & 84 deletions pyomo/repn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,15 +655,36 @@ def __init__(self):
self[SumExpression] = self._before_general_expression

@staticmethod
def _before_var(visitor, child):
_id = id(child)
def record_monomial(visitor, result, coef, var):
_id = id(var)
if _id not in visitor.var_map:
if child.fixed:
return False, (_CONSTANT, check_constant(child.value, child, visitor))
visitor.var_recorder.add(child)
if var.fixed:
var = check_constant(var.value, var, visitor)
if not coef and var.__class__ is InvalidNumber:
deprecation_warning(
f"Encountered {coef}*{val2str(var)} in expression "
"tree. Mapping the NaN result to 0 for compatibility "
"with the lp_v1 writer. In the future, this NaN "
"will be preserved/emitted to comply with IEEE-754.",
version='6.6.0',
)
else:
result.constant += coef * var
return
visitor.var_recorder.add(var)
if _id in result.linear:
result.linear[_id] += coef
elif coef:
result.linear[_id] = coef

@staticmethod
def _before_var(visitor, child):
ans = visitor.Result()
ans.linear[_id] = 1
return False, (_LINEAR, ans)
visitor.record_monomial(visitor, ans, 1, child)
if ans.linear:
return False, (_LINEAR, ans)
else:
return False, (_CONSTANT, ans.constant)

@staticmethod
def _before_monomial(visitor, child):
Expand All @@ -678,104 +699,45 @@ def _before_monomial(visitor, child):
except (ValueError, ArithmeticError):
return True, None

# We want to check / update the var_map before processing "0"
# coefficients so that we are consistent with what gets added to the
# var_map (e.g., 0*x*y: y is processed by _before_var and will
# always be added, but x is processed here)
_id = id(arg2)
if _id not in visitor.var_map:
if arg2.fixed:
return False, (
_CONSTANT,
arg1 * check_constant(arg2.value, arg2, visitor),
)
visitor.var_recorder.add(arg2)

# Trap multiplication by 0 and nan. Note that arg1 was reduced
# to a numeric value at the beginning of this method.
if not arg1:
if arg2.fixed:
arg2 = check_constant(arg2.value, arg2, visitor)
if arg2.__class__ is InvalidNumber:
deprecation_warning(
f"Encountered {arg1}*{val2str(arg2)} in expression "
"tree. Mapping the NaN result to 0 for compatibility "
"with the lp_v1 writer. In the future, this NaN "
"will be preserved/emitted to comply with IEEE-754.",
version='6.6.0',
)
return False, (_CONSTANT, arg1)

ans = visitor.Result()
ans.linear[_id] = arg1
return False, (_LINEAR, ans)
visitor.record_monomial(visitor, ans, arg1, arg2),
if ans.linear:
return False, (_LINEAR, ans)
else:
return False, (_CONSTANT, ans.constant)

@staticmethod
def _before_linear(visitor, child):
var_map = visitor.var_map
ans = visitor.Result()
recorder = visitor.record_monomial
evaluate = visitor.evaluate
const = 0
linear = ans.linear
for arg in child.args:
if arg.__class__ is MonomialTermExpression:
arg1, arg2 = arg._args_
if arg1.__class__ not in native_types:
try:
arg1 = check_constant(visitor.evaluate(arg1), arg1, visitor)
arg1 = check_constant(evaluate(arg1), arg1, visitor)
except (ValueError, ArithmeticError):
return True, None

# Trap multiplication by 0 and nan. Note that arg1 was
# reduced to a numeric value at the beginning of this
# method.
if not arg1:
if arg2.fixed:
arg2 = check_constant(arg2.value, arg2, visitor)
if arg2.__class__ is InvalidNumber:
deprecation_warning(
f"Encountered {arg1}*{val2str(arg2)} in expression "
"tree. Mapping the NaN result to 0 for compatibility "
"with the lp_v1 writer. In the future, this NaN "
"will be preserved/emitted to comply with IEEE-754.",
version='6.6.0',
)
continue

_id = id(arg2)
if _id not in var_map:
if arg2.fixed:
const += arg1 * check_constant(arg2.value, arg2, visitor)
continue
visitor.var_recorder.add(arg2)
linear[_id] = arg1
elif _id in linear:
linear[_id] += arg1
else:
linear[_id] = arg1
recorder(visitor, ans, arg1, arg2)
elif arg.__class__ in native_numeric_types:
const += arg
elif arg.is_variable_type():
_id = id(arg)
if _id not in var_map:
if arg.fixed:
const += check_constant(arg.value, arg, visitor)
continue
visitor.var_recorder.add(arg)
linear[_id] = 1
elif _id in linear:
linear[_id] += 1
else:
linear[_id] = 1
recorder(visitor, ans, 1, arg)
else:
# Fixed objects; e.g. Params & Param expressions
try:
const += check_constant(visitor.evaluate(arg), arg, visitor)
const += check_constant(evaluate(arg), arg, visitor)
except (ValueError, ArithmeticError):
return True, None
if linear:
ans.constant = const

ans.constant += const
if ans.linear:
return False, (_LINEAR, ans)
else:
return False, (_CONSTANT, const)
return False, (_CONSTANT, ans.constant)

@staticmethod
def _before_named_expression(visitor, child):
Expand Down Expand Up @@ -822,14 +784,15 @@ def __init__(
super().__init__()
self.subexpression_cache = subexpression_cache
if any(_ is not None for _ in (var_map, var_order, sorter)):
_name = self.__class__.__name__
if var_recorder is not None:
raise ValueError(
"LinearRepnVisitor: cannot specify any of var_map, "
f"{_name}: cannot specify any of var_map, "
"var_order, or sorter with var_recorder"
)
deprecation_warning(
"var_map, var_order, and sorter are deprecated arguments to "
"LinearRepnVisitor(). Please pass the VarRecorder object directly.",
f"{_name}(). Please pass the VarRecorder object directly.",
version='6.8.1',
)
var_recorder = OrderedVarRecorder(var_map, var_order, sorter)
Expand All @@ -839,6 +802,7 @@ def __init__(
)
self.var_recorder = var_recorder
self.var_map = var_recorder.var_map
self.record_monomial = self.before_child_dispatcher.record_monomial
self._eval_expr_visitor = _EvaluationVisitor(True)
self.evaluate = self._eval_expr_visitor.dfs_postorder_stack

Expand Down
9 changes: 7 additions & 2 deletions pyomo/repn/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,13 +572,18 @@ def test_monomial(self):
cfg = VisitorConfig()
with LoggingIntercept() as LOG:
repn = LinearRepnVisitor(**cfg).walk_expression(param_expr)
self.assertEqual(LOG.getvalue(), "")
self.assertRegex(
LOG.getvalue(),
r"DEPRECATED: Encountered 0\*InvalidNumber\(nan\) in expression tree.",
)

self.assertEqual(cfg.subexpr, {})
self.assertEqual(cfg.var_map, {})
self.assertEqual(cfg.var_order, {})
self.assertEqual(repn.multiplier, 1)
self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan))
# This will be true after we remove the deprecation warning:
# self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan))
self.assertEqual(repn.constant, 0)
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand Down
Loading