Skip to content

Commit

Permalink
Merge pull request #3190 from Robbybp/iden-vars-namedexpr
Browse files Browse the repository at this point in the history
Exploit repeated named expressions in `identify_variables`
  • Loading branch information
mrmundt committed May 8, 2024
2 parents 7c49bee + 252ce4f commit f8459a7
Show file tree
Hide file tree
Showing 6 changed files with 354 additions and 252 deletions.
150 changes: 122 additions & 28 deletions pyomo/core/expr/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,22 +1373,125 @@ def identify_components(expr, component_types):
# =====================================================


class _VariableVisitor(StreamBasedExpressionVisitor):
    """Visitor that collects the unique variables in an expression.

    Variables are de-duplicated by ``id()``.  The variables found inside
    each named expression are recorded in ``named_expression_cache`` so
    that a named expression that is encountered again is not re-walked.

    NOTE: a cache entry only holds the variables that passed the
    ``include_fixed`` filter in effect when the entry was built, and cached
    entries are reused without re-filtering.  A cache should therefore not
    be shared between visitors that use different ``include_fixed`` values.
    """

    def __init__(self, include_fixed=False, named_expression_cache=None):
        """Visitor that collects all unique variables participating in an
        expression

        Args:
            include_fixed (bool): Whether to include fixed variables
            named_expression_cache (optional, dict): Dict mapping ids of named
                expressions to a tuple of the list of all variables and the
                set of all variable ids contained in the named expression.

        """
        super().__init__()
        self._include_fixed = include_fixed
        if named_expression_cache is None:
            # This cache will map named expression ids to the
            # tuple: ([variables], {variable ids})
            named_expression_cache = {}
        self._named_expression_cache = named_expression_cache
        # Stack of active named expressions. This holds the id of
        # expressions we are currently in.
        self._active_named_expressions = []
        # "Global" result list and the ids it contains.  These are reset
        # in initializeWalker whenever a walk actually starts; they are
        # created here so the attributes always exist.
        self._variables = []
        self._seen = set()

    def initializeWalker(self, expr):
        """Prepare for a walk of ``expr``.

        Returns:
            A ``(walk, value)`` tuple.  When ``walk`` is False, ``value`` is
            the final list of variables and nothing is walked; when ``walk``
            is True, ``value`` is the expression to walk.
        """
        if expr.__class__ in native_types:
            return False, []
        elif expr.is_named_expression_type():
            eid = id(expr)
            if eid in self._named_expression_cache:
                # If we were given a named expression that is already cached,
                # just do nothing and return the expression's variables
                variables, var_set = self._named_expression_cache[eid]
                return False, variables
            else:
                # We were given a named expression that is not cached.
                # Initialize data structures and add this expression to the
                # stack. This expression will get popped in exitNode.
                self._variables = []
                self._seen = set()
                self._named_expression_cache[eid] = [], set()
                self._active_named_expressions.append(eid)
                return True, expr
        elif expr.is_variable_type():
            # Apply the same fixed-variable filter that beforeChild applies
            # to variables found inside an expression, so that a bare fixed
            # variable is not reported when include_fixed is False.
            if self._include_fixed or not expr.fixed:
                return False, [expr]
            return False, []
        else:
            self._variables = []
            self._seen = set()
            return True, expr

    def beforeChild(self, parent, child, index):
        """Record variables from ``child`` and decide whether to descend.

        Returns:
            ``(False, None)`` when ``child`` has been fully handled here
            (native value, cached named expression, or accepted variable),
            ``(True, None)`` when the walker should descend into it.
        """
        if child.__class__ in native_types:
            return False, None
        elif child.is_named_expression_type():
            eid = id(child)
            if eid in self._named_expression_cache:
                # We have already encountered this named expression. We just add
                # the cached variables to our list and don't descend.
                if self._active_named_expressions:
                    # If we are in another named expression, we update the
                    # parent expression's cache. We don't need to update the
                    # global list as we will do this when we exit the active
                    # named expression.
                    parent_eid = self._active_named_expressions[-1]
                    variables, var_set = self._named_expression_cache[parent_eid]
                else:
                    # If we are not in a named expression, we update the global
                    # list.
                    variables = self._variables
                    var_set = self._seen
                for var in self._named_expression_cache[eid][0]:
                    if id(var) not in var_set:
                        var_set.add(id(var))
                        variables.append(var)
                return False, None
            else:
                # If we are descending into a new named expression, initialize
                # a cache to store the expression's local variables.
                self._named_expression_cache[eid] = ([], set())
                self._active_named_expressions.append(eid)
                return True, None
        elif child.is_variable_type() and (self._include_fixed or not child.fixed):
            if self._active_named_expressions:
                # If we are in a named expression, add new variables to the cache.
                eid = self._active_named_expressions[-1]
                variables, var_set = self._named_expression_cache[eid]
            else:
                variables = self._variables
                var_set = self._seen
            if id(child) not in var_set:
                var_set.add(id(child))
                variables.append(child)
            return False, None
        else:
            return True, None

    def exitNode(self, node, data):
        """Propagate a finished named expression's variables outward.

        When ``node`` is a named expression, its id is popped from the
        active stack and its cached variables are merged into the enclosing
        named expression's cache entry, or into the global list when no
        named expression remains active.  Other nodes are ignored.
        """
        if node.is_named_expression_type():
            # If we are returning from a named expression, we have at least one
            # active named expression. We must make sure that we properly
            # handle the variables for the named expression we just exited.
            eid = self._active_named_expressions.pop()
            if self._active_named_expressions:
                # If we still are in a named expression, we update that expression's
                # cache with any new variables encountered.
                parent_eid = self._active_named_expressions[-1]
                variables, var_set = self._named_expression_cache[parent_eid]
            else:
                variables = self._variables
                var_set = self._seen
            for var in self._named_expression_cache[eid][0]:
                if id(var) not in var_set:
                    var_set.add(id(var))
                    variables.append(var)

    def finalizeResult(self, result):
        """Return the list of variables collected during the walk."""
        return self._variables


def identify_variables(expr, include_fixed=True, named_expression_cache=None):
    """
    A generator that yields a sequence of variables
    in an expression tree.

    Args:
        expr: The root node of an expression tree.
        include_fixed (bool): Forwarded to the underlying visitor to
            control whether fixed variables are reported.
            Defaults to True.
        named_expression_cache (optional, dict): Dict mapping ids of
            named expressions to a tuple of the list of all variables and
            the set of all variable ids contained in the named expression.
            A fresh, empty dict is used when this is None; passing the
            same dict to several calls lets them share cached results.

    Yields:
        Each variable that is found.
    """
    # Because this is a generator, nothing below runs until the caller
    # starts iterating.
    cache = {} if named_expression_cache is None else named_expression_cache
    walker = _VariableVisitor(include_fixed=include_fixed, named_expression_cache=cache)
    yield from walker.walk_expression(expr)


# =====================================================
Expand Down
17 changes: 12 additions & 5 deletions pyomo/core/tests/unit/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def test_identify_vars_vars(self):
self.assertEqual(list(identify_variables(m.a + m.b[1])), [m.a, m.b[1]])
self.assertEqual(list(identify_variables(m.a ** m.b[1])), [m.a, m.b[1]])
self.assertEqual(
list(identify_variables(m.a ** m.b[1] + m.b[2])), [m.b[2], m.a, m.b[1]]
ComponentSet(identify_variables(m.a ** m.b[1] + m.b[2])),
ComponentSet([m.b[2], m.a, m.b[1]]),
)
self.assertEqual(
list(identify_variables(m.a ** m.b[1] + m.b[2] * m.b[3] * m.b[2])),
Expand All @@ -159,14 +160,20 @@ def test_identify_vars_vars(self):
# Identify variables in the arguments to functions
#
self.assertEqual(
list(identify_variables(m.x(m.a, 'string_param', 1, []) * m.b[1])),
[m.b[1], m.a],
ComponentSet(identify_variables(m.x(m.a, 'string_param', 1, []) * m.b[1])),
ComponentSet([m.b[1], m.a]),
)
self.assertEqual(
list(identify_variables(m.x(m.p, 'string_param', 1, []) * m.b[1])), [m.b[1]]
)
self.assertEqual(list(identify_variables(tanh(m.a) * m.b[1])), [m.b[1], m.a])
self.assertEqual(list(identify_variables(abs(m.a) * m.b[1])), [m.b[1], m.a])
self.assertEqual(
ComponentSet(identify_variables(tanh(m.a) * m.b[1])),
ComponentSet([m.b[1], m.a]),
)
self.assertEqual(
ComponentSet(identify_variables(abs(m.a) * m.b[1])),
ComponentSet([m.b[1], m.a]),
)
#
# Check logic for allowing duplicates
#
Expand Down
Loading

0 comments on commit f8459a7

Please sign in to comment.