Skip to content

Commit

Permalink
Merge pull request #14 from FEniCS/wence/fix/traversal-slow
Browse files Browse the repository at this point in the history
Replace hand-coded stack in traversal with growable one
  • Loading branch information
wence- committed Nov 7, 2019
2 parents afcc31c + 6053c43 commit 287c42a
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 142 deletions.
43 changes: 14 additions & 29 deletions ufl/core/compute_expr_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""Non-recursive traversal-based hash computation algorithm.
Fast iteration over nodes in an ``Expr`` DAG to compute
memorized hashes for all unique nodes.
memoized hashes for all unique nodes.
"""

# Copyright (C) 2015 Martin Sandve Alnæs
Expand All @@ -13,38 +13,23 @@
#
# Modified by Massimiliano Leoni, 2016

# This limits the _depth_ of expression trees
_recursion_limit_ = 6400 # should be enough for everyone


def compute_expr_hash(expr):
"""Compute hashes of *expr* and all its nodes efficiently, without using Python recursion."""
if expr._hash is not None:
return expr._hash

stack = [None] * _recursion_limit_
stacksize = 0

ops = expr.ufl_operands
stack[stacksize] = [expr, ops, len(ops)]
stacksize += 1

while stacksize > 0:
entry = stack[stacksize - 1]
e = entry[0]
if e._hash is not None:
# cutoff: don't need to visit children when hash has previously been computed
stacksize -= 1
elif entry[2] == 0:
# all children consumed: trigger memoized hash computation
e._hash = e._ufl_compute_hash_()
stacksize -= 1
# Postorder traversal, can't use unique_post_traversal, since that
# uses a set which requires that this hash is computed.
lifo = [(expr, list(expr.ufl_operands))]
while lifo:
expr, deps = lifo[-1]
for i, dep in enumerate(deps):
if dep is not None and dep._hash is None:
lifo.append((dep, list(dep.ufl_operands)))
deps[i] = None
break
else:
# add children to stack to hash them first
entry[2] -= 1
o = entry[1][entry[2]]
oops = o.ufl_operands
stack[stacksize] = [o, oops, len(oops)]
stacksize += 1

if expr._hash is None:
expr._hash = expr._ufl_compute_hash_()
lifo.pop()
return expr._hash
179 changes: 66 additions & 113 deletions ufl/corealg/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,171 +13,124 @@
#
# Modified by Massimiliano Leoni, 2016

# This limits the _depth_ of expression trees
_recursion_limit_ = 6400 # should be enough for everyone


def pre_traversal(expr):
"""Yield ``o`` for each tree node ``o`` in *expr*, parent before child."""
stack = [None] * _recursion_limit_
stack[0] = expr
stacksize = 1
while stacksize > 0:
stacksize -= 1
expr = stack[stacksize]
lifo = [expr]
while lifo:
expr = lifo.pop()
yield expr
for op in expr.ufl_operands:
stack[stacksize] = op
stacksize += 1
lifo.append(op)


def post_traversal(expr):
"""Yield ``o`` for each node ``o`` in *expr*, child before parent."""
stack = [None] * _recursion_limit_
stacksize = 0

ops = expr.ufl_operands
stack[stacksize] = [expr, ops, len(ops)]
stacksize += 1

while stacksize > 0:
entry = stack[stacksize - 1]
if entry[2] == 0:
yield entry[0]
stacksize -= 1
lifo = [(expr, list(reversed(expr.ufl_operands)))]
while lifo:
expr, deps = lifo[-1]
for i, dep in enumerate(deps):
if dep is not None:
lifo.append((dep, list(reversed(dep.ufl_operands))))
deps[i] = None
break
else:
entry[2] -= 1
o = entry[1][entry[2]]
oops = o.ufl_operands
stack[stacksize] = [o, oops, len(oops)]
stacksize += 1
yield expr
lifo.pop()


def cutoff_post_traversal(expr, cutofftypes):
"""Yield ``o`` for each node ``o`` in *expr*, child before parent, but
skipping subtrees of the cutofftypes."""
stack = [None] * _recursion_limit_
stacksize = 0

ops = expr.ufl_operands
stack[stacksize] = [expr, ops, len(ops)]
stacksize += 1

while stacksize > 0:
entry = stack[stacksize - 1]
expr = entry[0]
if entry[2] == 0 or cutofftypes[expr._ufl_typecode_]:
lifo = [(expr, list(reversed(expr.ufl_operands)))]
while lifo:
expr, deps = lifo[-1]
if cutofftypes[expr._ufl_typecode_]:
yield expr
stacksize -= 1
lifo.pop()
else:
entry[2] -= 1
o = entry[1][entry[2]]
if cutofftypes[expr._ufl_typecode_]:
oops = ()
for i, dep in enumerate(deps):
if dep is not None:
lifo.append((dep, list(reversed(dep.ufl_operands))))
deps[i] = None
break
else:
oops = o.ufl_operands
stack[stacksize] = [o, oops, len(oops)]
stacksize += 1
yield expr
lifo.pop()


def unique_pre_traversal(expr, visited=None):
"""Yield ``o`` for each tree node ``o`` in *expr*, parent before child.
This version only visits each node once.
"""
stack = [None] * _recursion_limit_
stack[0] = expr
stacksize = 1
if visited is None:
visited = set()
while stacksize > 0:
stacksize -= 1
expr = stack[stacksize]
if expr not in visited:
visited.add(expr)
yield expr
for op in expr.ufl_operands:
stack[stacksize] = op
stacksize += 1
lifo = [expr]
visited.add(expr)

while lifo:
expr = lifo.pop()
yield expr
for op in expr.ufl_operands:
if op not in visited:
lifo.append(op)
visited.add(op)


def unique_post_traversal(expr, visited=None):
"""Yield ``o`` for each node ``o`` in *expr*, child before parent.
Never visit a node twice."""
stack = [None] * _recursion_limit_
stack[0] = (expr, list(expr.ufl_operands))
stacksize = 1
lifo = [(expr, list(expr.ufl_operands))]
if visited is None:
visited = set()
while stacksize > 0:
expr, ops = stack[stacksize - 1]
for i, o in enumerate(ops):
if o is not None and o not in visited:
stack[stacksize] = (o, list(o.ufl_operands))
stacksize += 1
ops[i] = None
visited.add(expr)
while lifo:
expr, deps = lifo[-1]
for i, dep in enumerate(deps):
if dep is not None and dep not in visited:
lifo.append((dep, list(dep.ufl_operands)))
deps[i] = None
break
else:
yield expr
visited.add(expr)
stacksize -= 1
lifo.pop()


def cutoff_unique_post_traversal(expr, cutofftypes, visited=None):
"""Yield ``o`` for each node ``o`` in *expr*, child before parent.
Never visit a node twice."""
stack = [None] * _recursion_limit_
stack[0] = (expr, () if cutofftypes[expr._ufl_typecode_] else list(expr.ufl_operands))
stacksize = 1
lifo = [(expr, list(reversed(expr.ufl_operands)))]
if visited is None:
visited = set()
while stacksize > 0:
expr, ops = stack[stacksize - 1]
for i, o in enumerate(ops):
if o is not None and o not in visited:
stack[stacksize] = (o, () if cutofftypes[o._ufl_typecode_] else list(o.ufl_operands))
stacksize += 1
ops[i] = None
break
else:
while lifo:
expr, deps = lifo[-1]
if cutofftypes[expr._ufl_typecode_]:
yield expr
visited.add(expr)
stacksize -= 1
lifo.pop()
else:
for i, dep in enumerate(deps):
if dep is not None and dep not in visited:
lifo.append((dep, list(reversed(dep.ufl_operands))))
deps[i] = None
break
else:
yield expr
visited.add(expr)
lifo.pop()


def traverse_terminals(expr):
"Iterate over all terminal objects in *expr*, including duplicates."
stack = [None] * _recursion_limit_
stack[0] = expr
stacksize = 1
while stacksize > 0:
stacksize -= 1
expr = stack[stacksize]
if expr._ufl_is_terminal_:
yield expr
else:
for op in expr.ufl_operands:
stack[stacksize] = op
stacksize += 1
for op in pre_traversal(expr):
if op._ufl_is_terminal_:
yield op


def traverse_unique_terminals(expr, visited=None):
"Iterate over all terminal objects in *expr*, not including duplicates."
stack = [None] * _recursion_limit_
stack[0] = expr
stacksize = 1
if visited is None:
visited = set()
while stacksize > 0:
stacksize -= 1
expr = stack[stacksize]
if expr not in visited:
visited.add(expr)
if expr._ufl_is_terminal_:
yield expr
else:
for op in expr.ufl_operands:
stack[stacksize] = op
stacksize += 1
for op in unique_pre_traversal(expr, visited=visited):
if op._ufl_is_terminal_:
yield op

0 comments on commit 287c42a

Please sign in to comment.