Skip to content

Commit

Permalink
complex: Propagate type correctly for Indexed in type checker (#20)
Browse files Browse the repository at this point in the history
* complex: Propagate type correctly for Indexed in type checker

* Don't try to take sign of a complex

* Move complex node removal to the correct spot

Co-authored-by: David Ham <David.Ham@imperial.ac.uk>
  • Loading branch information
wence- and dham committed May 12, 2020
1 parent f042055 commit 88933f7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
7 changes: 5 additions & 2 deletions ufl/algorithms/apply_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,11 @@ def power(self, o, fp, gp):

def abs(self, o, df):
f, = o.ufl_operands
# return conditional(eq(f, 0), 0, Product(sign(f), df))
return sign(f) * df
# return conditional(eq(f, 0), 0, Product(sign(f), df)) abs is
# not complex differentiable, so we workaround the case of a
# real F in complex mode by defensively casting to real inside
# the sign.
return sign(Real(f)) * df

# --- Complex algebra

Expand Down
5 changes: 5 additions & 0 deletions ufl/algorithms/comparison_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ def terminal(self, term, *ops):
self.nodetype[term] = 'complex'
return term

def indexed(self, o, expr, multiindex):
o = self.reuse_if_untouched(o, expr, multiindex)
self.nodetype[o] = self.nodetype[expr]
return o


def do_comparison_check(form):
"""Raises an error if invalid comparison nodes exist"""
Expand Down
8 changes: 4 additions & 4 deletions ufl/algorithms/compute_form_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,10 @@ def compute_form_data(form,
if do_apply_restrictions:
form = apply_restrictions(form)

# If in real mode, remove any complex nodes introduced during form processing.
if not complex_mode:
form = remove_complex_nodes(form)

# --- Group integrals into IntegralData objects
# Most of the heavy lifting is done above in group_form_integrals.
self.integral_data = build_integral_data(form.integrals())
Expand Down Expand Up @@ -400,10 +404,6 @@ def compute_form_data(form,
# faster!
preprocessed_form = reconstruct_form_from_integral_data(self.integral_data)

# If in real mode, remove complex nodes entirely.
if not complex_mode:
preprocessed_form = remove_complex_nodes(preprocessed_form)

check_form_arity(preprocessed_form, self.original_form.arguments(), complex_mode) # Currently testing how fast this is

# TODO: This member is used by unit tests, change the tests to
Expand Down

0 comments on commit 88933f7

Please sign in to comment.