Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exploit repeated named expressions in identify_variables #3190

Merged
merged 25 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
313a310
initial implementation of variable visitor that can exploit named exp…
Robbybp Mar 7, 2024
e9e63a0
use new variable visitor in get_vars_from_components rather than iden…
Robbybp Mar 11, 2024
38c78a3
remove unnecessary walker callbacks
Robbybp Mar 11, 2024
f32bd2e
method args on single line
Robbybp Mar 12, 2024
87644ca
update condition for skipping named expression in variable visitor
Robbybp Mar 12, 2024
e6e7259
super.__init__ call in variable visitor
Robbybp Mar 12, 2024
591f899
Merge branch 'main' of https://github.com/pyomo/pyomo into iden-vars-…
Robbybp Mar 14, 2024
e1fa256
[WIP] initial attempt at implementing identify_variables with a named…
Robbybp Mar 14, 2024
5ea6ae7
potentially working implementation of identify-variables with efficie…
Robbybp Mar 15, 2024
a82c750
update identify_variables tests to use ComponentSet to not rely on va…
Robbybp Mar 15, 2024
f8299f6
remove commented code and old identify_variables implementation
Robbybp Mar 15, 2024
d232fca
handle variable at root in initializeWalker rather than exitNode
Robbybp Mar 15, 2024
796ccef
remove previous vars_from_expressions implementation
Robbybp Mar 15, 2024
7c130ef
add docstring and comments to _StreamVariableVisitor
Robbybp Mar 15, 2024
0e3015d
arguments on single line
Robbybp Mar 15, 2024
9f8a155
Merge branch 'main' of https://github.com/pyomo/pyomo into iden-vars-…
Robbybp Mar 15, 2024
c64dcf4
consolidate logic for adding variable to set
Robbybp Mar 16, 2024
fed3c33
fix typo
Robbybp Mar 18, 2024
9fd202e
update GDP baselines to reflect change in variable order?
Robbybp Mar 19, 2024
ae439ad
use get_vars_from_components in create_subsystem_block
Robbybp Mar 19, 2024
be31a20
formatting fix
Robbybp Mar 19, 2024
08b9d93
remove unused imports
Robbybp Mar 19, 2024
443826d
remove old _VariableVisitor and rename new visitor to _VariableVisitor
Robbybp Mar 19, 2024
22e11e0
Merge branch 'main' into iden-vars-namedexpr
Robbybp Mar 19, 2024
252ce4f
Merge branch 'main' into iden-vars-namedexpr
blnicho May 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 122 additions & 28 deletions pyomo/core/expr/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,22 +1373,125 @@
# =====================================================


class _VariableVisitor(SimpleExpressionVisitor):
def __init__(self):
self.seen = set()
class _VariableVisitor(StreamBasedExpressionVisitor):
def __init__(self, include_fixed=False, named_expression_cache=None):
"""Visitor that collects all unique variables participating in an
expression

def visit(self, node):
if node.__class__ in nonpyomo_leaf_types:
return
Args:
include_fixed (bool): Whether to include fixed variables
named_expression_cache (optional, dict): Dict mapping ids of named
expressions to a tuple of the list of all variables and the
set of all variable ids contained in the named expression.

if node.is_variable_type():
if id(node) in self.seen:
return
self.seen.add(id(node))
return node
"""
super().__init__()
self._include_fixed = include_fixed
if named_expression_cache is None:
# This cache will map named expression ids to the
# tuple: ([variables], {variable ids})
named_expression_cache = {}

Check warning on line 1393 in pyomo/core/expr/visitor.py

View check run for this annotation

Codecov / codecov/patch

pyomo/core/expr/visitor.py#L1393

Added line #L1393 was not covered by tests
self._named_expression_cache = named_expression_cache
# Stack of active named expressions. This holds the id of
# expressions we are currently in.
self._active_named_expressions = []

def initializeWalker(self, expr):
if expr.__class__ in native_types:
return False, []
elif expr.is_named_expression_type():
eid = id(expr)
if eid in self._named_expression_cache:
# If we were given a named expression that is already cached,
# just do nothing and return the expression's variables
variables, var_set = self._named_expression_cache[eid]
return False, variables

Check warning on line 1408 in pyomo/core/expr/visitor.py

View check run for this annotation

Codecov / codecov/patch

pyomo/core/expr/visitor.py#L1407-L1408

Added lines #L1407 - L1408 were not covered by tests
else:
# We were given a named expression that is not cached.
# Initialize data structures and add this expression to the
# stack. This expression will get popped in exitNode.
self._variables = []
self._seen = set()
self._named_expression_cache[eid] = [], set()
self._active_named_expressions.append(eid)
return True, expr
elif expr.is_variable_type():
return False, [expr]
else:
self._variables = []
self._seen = set()
return True, expr

def identify_variables(expr, include_fixed=True):
def beforeChild(self, parent, child, index):
if child.__class__ in native_types:
return False, None
elif child.is_named_expression_type():
eid = id(child)
if eid in self._named_expression_cache:
# We have already encountered this named expression. We just add
# the cached variables to our list and don't descend.
if self._active_named_expressions:
# If we are in another named expression, we update the
# parent expression's cache. We don't need to update the
# global list as we will do this when we exit the active
# named expression.
parent_eid = self._active_named_expressions[-1]
variables, var_set = self._named_expression_cache[parent_eid]
else:
# If we are not in a named expression, we update the global
# list.
variables = self._variables
var_set = self._seen
for var in self._named_expression_cache[eid][0]:
if id(var) not in var_set:
var_set.add(id(var))
variables.append(var)
return False, None
else:
# If we are descending into a new named expression, initialize
# a cache to store the expression's local variables.
self._named_expression_cache[id(child)] = ([], set())
self._active_named_expressions.append(id(child))
return True, None
elif child.is_variable_type() and (self._include_fixed or not child.fixed):
if self._active_named_expressions:
# If we are in a named expression, add new variables to the cache.
eid = self._active_named_expressions[-1]
variables, var_set = self._named_expression_cache[eid]
else:
variables = self._variables
var_set = self._seen
if id(child) not in var_set:
var_set.add(id(child))
variables.append(child)
return False, None
else:
return True, None

def exitNode(self, node, data):
if node.is_named_expression_type():
# If we are returning from a named expression, we have at least one
# active named expression. We must make sure that we properly
# handle the variables for the named expression we just exited.
eid = self._active_named_expressions.pop()
if self._active_named_expressions:
# If we still are in a named expression, we update that expression's
# cache with any new variables encountered.
parent_eid = self._active_named_expressions[-1]
variables, var_set = self._named_expression_cache[parent_eid]
else:
variables = self._variables
var_set = self._seen
for var in self._named_expression_cache[eid][0]:
if id(var) not in var_set:
var_set.add(id(var))
variables.append(var)

def finalizeResult(self, result):
return self._variables


def identify_variables(expr, include_fixed=True, named_expression_cache=None):
"""
A generator that yields a sequence of variables
in an expression tree.
Expand All @@ -1402,22 +1505,13 @@
Yields:
Each variable that is found.
"""
visitor = _VariableVisitor()
if include_fixed:
for v in visitor.xbfs_yield_leaves(expr):
if isinstance(v, tuple):
yield from v
else:
yield v
else:
for v in visitor.xbfs_yield_leaves(expr):
if isinstance(v, tuple):
for v_i in v:
if not v_i.is_fixed():
yield v_i
else:
if not v.is_fixed():
yield v
if named_expression_cache is None:
named_expression_cache = {}
visitor = _VariableVisitor(
named_expression_cache=named_expression_cache, include_fixed=include_fixed
)
variables = visitor.walk_expression(expr)
yield from variables


# =====================================================
Expand Down
17 changes: 12 additions & 5 deletions pyomo/core/tests/unit/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def test_identify_vars_vars(self):
self.assertEqual(list(identify_variables(m.a + m.b[1])), [m.a, m.b[1]])
self.assertEqual(list(identify_variables(m.a ** m.b[1])), [m.a, m.b[1]])
self.assertEqual(
list(identify_variables(m.a ** m.b[1] + m.b[2])), [m.b[2], m.a, m.b[1]]
ComponentSet(identify_variables(m.a ** m.b[1] + m.b[2])),
ComponentSet([m.b[2], m.a, m.b[1]]),
)
self.assertEqual(
list(identify_variables(m.a ** m.b[1] + m.b[2] * m.b[3] * m.b[2])),
Expand All @@ -159,14 +160,20 @@ def test_identify_vars_vars(self):
# Identify variables in the arguments to functions
#
self.assertEqual(
list(identify_variables(m.x(m.a, 'string_param', 1, []) * m.b[1])),
[m.b[1], m.a],
ComponentSet(identify_variables(m.x(m.a, 'string_param', 1, []) * m.b[1])),
ComponentSet([m.b[1], m.a]),
)
self.assertEqual(
list(identify_variables(m.x(m.p, 'string_param', 1, []) * m.b[1])), [m.b[1]]
)
self.assertEqual(list(identify_variables(tanh(m.a) * m.b[1])), [m.b[1], m.a])
self.assertEqual(list(identify_variables(abs(m.a) * m.b[1])), [m.b[1], m.a])
self.assertEqual(
ComponentSet(identify_variables(tanh(m.a) * m.b[1])),
ComponentSet([m.b[1], m.a]),
)
self.assertEqual(
ComponentSet(identify_variables(abs(m.a) * m.b[1])),
ComponentSet([m.b[1], m.a]),
)
#
# Check logic for allowing duplicates
#
Expand Down
Loading