Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add check_restrictions in compute_form_data #34

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion ufl/algorithms/check_restrictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,34 @@
from ufl.corealg.multifunction import MultiFunction
from ufl.corealg.map_dag import map_expr_dag

from ufl.classes import Restricted


class RestrictionChecker(MultiFunction):
def __init__(self, require_restriction):
MultiFunction.__init__(self)
self.current_restriction = None
self.require_restriction = require_restriction

def find_restriction(self, ops):
if self.current_restriction is not None:
return
for op in ops.ufl_operands:
if isinstance(op, Restricted):
self.current_restriction = op.side()
break
self.find_restriction(op)

def expr(self, o):
pass
# Check is given expression involves restrictions
self.find_restriction(o)
# Restrictions are needed and only allowed for interior facet integrals
if self.require_restriction:
if self.current_restriction is None:
error("Form argument must be restricted in interior facet integrals.")
else:
if self.current_restriction is not None:
error("Restrictions are only allowed for interior facet and custom integral types.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hand-coded recursion in find_restriction unfortunately isn't a scalable approach [it's quadratic in tree depth]. Unfortunately, to fix this is a little bit of work, the map_dag infrastructure provides facility for DAG-based visiting and memoization of post-order traversal (child before parent), but not the preorder (parent before child) traversal we need here. I think this is a wider problem.

I think you can phrase this visit as post-order traversal, but it's a little fiddly.

Unrelated to this patch, reading this code I wondered how it works at all, since in def restricted it calls self.visit but visit is not defined anywhere.

I think this was not completely ported from the old ReuseTransformer infrastructure.

It's not previously been used in the normal pipeline (we have some error checking via apply_default_restrictions that is only applied to interior_facet integrals), so this is why we haven't ever noticed.


def restricted(self, o):
if self.current_restriction is not None:
Expand Down
16 changes: 15 additions & 1 deletion ufl/algorithms/compute_form_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@

from ufl.log import error, info
from ufl.utils.sequences import max_degree
from ufl.measure import custom_integral_types

from ufl.classes import GeometricFacetQuantity, Coefficient, Form, FunctionSpace
from ufl.corealg.traversal import traverse_unique_terminals
from ufl.algorithms.analysis import extract_coefficients, extract_sub_elements, unique_tuple
from ufl.algorithms.formdata import FormData
from ufl.algorithms.formtransformations import compute_form_arities
from ufl.algorithms.check_arities import check_form_arity
from ufl.algorithms.check_restrictions import check_restrictions

# These are the main symbolic processing steps:
from ufl.algorithms.apply_function_pullbacks import apply_function_pullbacks
Expand Down Expand Up @@ -140,7 +142,7 @@ def _check_elements(form_data):
for element in chain(form_data.unique_elements,
form_data.unique_sub_elements):
if element.family() is None:
error("Found element with undefined familty: %s" % repr(element))
error("Found element with undefined family: %s" % repr(element))
if element.cell() is None:
error("Found element with undefined cell: %s" % repr(element))

Expand Down Expand Up @@ -169,6 +171,17 @@ def _check_form_arity(preprocessed_form):
error("All terms in form must have same rank.")


def _check_restrictions(integral_data):
# Only allow restrictions on interior facet integrals
for itg_data in integral_data:
for itg in itg_data.integrals:
if itg_data.integral_type not in custom_integral_types: # Allowing custom integrals to pass
if itg_data.integral_type.startswith("interior_facet"):
check_restrictions(itg.integrand(), True)
else:
check_restrictions(itg.integrand(), False)


def _build_coefficient_replace_map(coefficients, element_mapping=None):
"""Create new Coefficient objects
with count starting at 0. Return mapping from old
Expand Down Expand Up @@ -399,6 +412,7 @@ def compute_form_data(form,
# --- Checks
_check_elements(self)
_check_facet_geometry(self.integral_data)
_check_restrictions(self.integral_data)

# TODO: This is a very expensive check... Replace with something
# faster!
Expand Down