Skip to content

Commit

Permalink
Extend replace to be able to handle non-terminals (#15)
Browse files Browse the repository at this point in the history
* Extend replace to be able to handle non-terminals

* update docstring to reflect relaxed constraints

Co-authored-by: Lawrence Mitchell <wence@gmx.li>
  • Loading branch information
dham and wence- committed Jan 29, 2020
1 parent 287c42a commit 4516a6d
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions ufl/algorithms/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,23 @@

class Replacer(MultiFunction):
def __init__(self, mapping):
MultiFunction.__init__(self)
self._mapping = mapping
if not all(k._ufl_is_terminal_ for k in mapping.keys()):
error("This implementation can only replace Terminal objects.")
super().__init__()
self.mapping = mapping
if not all(k.ufl_shape == v.ufl_shape for k, v in mapping.items()):
error("Replacement expressions must have the same shape as what they replace.")

expr = MultiFunction.reuse_if_untouched

def terminal(self, o):
e = self._mapping.get(o)
if e is None:
return o
else:
return e
def expr(self, o, *args):
try:
return self.mapping[o]
except KeyError:
return self.reuse_if_untouched(o, *args)

def coefficient_derivative(self, o):
error("Derivatives should be applied before executing replace.")


def replace(e, mapping):
"""Replace terminal objects in expression.
"""Replace subexpressions in expression.
@param e:
An Expr or Form.
Expand All @@ -50,6 +45,14 @@ def replace(e, mapping):
mapping2 = dict((k, as_ufl(v)) for (k, v) in mapping.items())

# Workaround for problem with delayed derivative evaluation
# The problem is that J = derivative(f(g, h), g) does not evaluate immediately
# So if we subsequently do replace(J, {g: h}) we end up with an expression:
# derivative(f(h, h), h)
# rather than what were were probably thinking of:
# replace(derivative(f(g, h), g), {g: h})
#
# To fix this would require one to expand derivatives early (which
# is not attractive), or make replace lazy too.
if has_exact_type(e, CoefficientDerivative):
# Hack to avoid circular dependencies
from ufl.algorithms.ad import expand_derivatives
Expand Down

0 comments on commit 4516a6d

Please sign in to comment.