Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for deprecation of a UFL function #549

Merged
merged 4 commits into from
Dec 24, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions ffcx/codegeneration/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,23 +179,23 @@ def jacobian(self, e, mt, tabledata, num_points):

def reference_cell_volume(self, e, mt, tabledata, access):
L = self.language
cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
return L.Symbol(f"{cellname}_reference_cell_volume")
else:
raise RuntimeError(f"Unhandled cell types {cellname}.")

def reference_facet_volume(self, e, mt, tabledata, access):
L = self.language
cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
return L.Symbol(f"{cellname}_reference_facet_volume")
else:
raise RuntimeError(f"Unhandled cell types {cellname}.")

def reference_normal(self, e, mt, tabledata, access):
L = self.language
cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
table = L.Symbol(f"{cellname}_reference_facet_normals")
facet = self.symbols.entity("facet", mt.restriction)
Expand All @@ -205,7 +205,7 @@ def reference_normal(self, e, mt, tabledata, access):

def cell_facet_jacobian(self, e, mt, tabledata, num_points):
L = self.language
cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
table = L.Symbol(f"{cellname}_reference_facet_jacobian")
facet = self.symbols.entity("facet", mt.restriction)
Expand All @@ -217,7 +217,7 @@ def cell_facet_jacobian(self, e, mt, tabledata, num_points):

def reference_cell_edge_vectors(self, e, mt, tabledata, num_points):
L = self.language
cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
table = L.Symbol(f"{cellname}_reference_edge_vectors")
return table[mt.component[0]][mt.component[1]]
Expand All @@ -228,7 +228,7 @@ def reference_cell_edge_vectors(self, e, mt, tabledata, num_points):

def reference_facet_edge_vectors(self, e, mt, tabledata, num_points):
L = self.language
cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
if cellname in ("tetrahedron", "hexahedron"):
table = L.Symbol(f"{cellname}_reference_edge_vectors")
facet = self.symbols.entity("facet", mt.restriction)
Expand All @@ -242,7 +242,7 @@ def reference_facet_edge_vectors(self, e, mt, tabledata, num_points):

def facet_orientation(self, e, mt, tabledata, num_points):
L = self.language
cellname = mt.terminal.ufl_domain().ufl_cell().cellname()
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
if cellname not in ("interval", "triangle", "tetrahedron"):
raise RuntimeError(f"Unhandled cell types {cellname}.")

Expand All @@ -252,7 +252,7 @@ def facet_orientation(self, e, mt, tabledata, num_points):

def cell_vertices(self, e, mt, tabledata, num_points):
# Get properties of domain
domain = mt.terminal.ufl_domain()
domain = ufl.domain.extract_unique_domain(mt.terminal)
gdim = domain.geometric_dimension()
coordinate_element = convert_element(domain.ufl_coordinate_element())

Expand All @@ -275,7 +275,7 @@ def cell_vertices(self, e, mt, tabledata, num_points):

def cell_edge_vectors(self, e, mt, tabledata, num_points):
# Get properties of domain
domain = mt.terminal.ufl_domain()
domain = ufl.domain.extract_unique_domain(mt.terminal)
cellname = domain.ufl_cell().cellname()
gdim = domain.geometric_dimension()
coordinate_element = convert_element(domain.ufl_coordinate_element())
Expand Down Expand Up @@ -316,7 +316,7 @@ def facet_edge_vectors(self, e, mt, tabledata, num_points):
L = self.language

# Get properties of domain
domain = mt.terminal.ufl_domain()
domain = ufl.domain.extract_unique_domain(mt.terminal)
cellname = domain.ufl_cell().cellname()
gdim = domain.geometric_dimension()
coordinate_element = convert_element(domain.ufl_coordinate_element())
Expand Down
2 changes: 1 addition & 1 deletion ffcx/codegeneration/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _define_coordinate_dofs_lincomb(self, e, mt, tabledata, quadrature_rule, acc
L = self.language

# Get properties of domain
domain = mt.terminal.ufl_domain()
domain = ufl.domain.extract_unique_domain(mt.terminal)
coordinate_element = domain.ufl_coordinate_element()
num_scalar_dofs = create_element(coordinate_element).sub_element.dim

Expand Down
2 changes: 1 addition & 1 deletion ffcx/codegeneration/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def generate_geometry_tables(self, float_type: str):
if mt is not None:
t = type(mt.terminal)
if t in ufl_geometry:
cells[t].add(mt.terminal.ufl_domain().ufl_cell().cellname())
cells[t].add(ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname())

parts = []
for i, cell_list in cells.items():
Expand Down
3 changes: 2 additions & 1 deletion ffcx/codegeneration/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ def generate_geometry_tables(self, float_type: str):
if mt is not None:
t = type(mt.terminal)
if t in ufl_geometry:
cells[t].add(mt.terminal.ufl_domain().ufl_cell().cellname())
cells[t].add(
ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname())

parts = []
for i, cell_list in cells.items():
Expand Down
2 changes: 1 addition & 1 deletion ffcx/codegeneration/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def format_mt_name(basename, mt):
# Format local derivatives
if mt.local_derivatives:
# Convert "listing" derivative multindex into "counting" representation
gdim = mt.terminal.ufl_domain().geometric_dimension()
gdim = ufl.domain.extract_unique_domain(mt.terminal).geometric_dimension()
ld_counting = ufl.utils.derivativetuples.derivative_listing_to_counts(mt.local_derivatives, gdim)
der = f"_d{''.join(map(str, ld_counting))}"
access += der
Expand Down
4 changes: 2 additions & 2 deletions ffcx/ir/analysis/valuenumbering.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ def _modified_terminal(self, v):
num_gd = len(mt.global_derivatives)
assert not (num_ld and num_gd)
if num_ld:
domain = mt.terminal.ufl_domain()
domain = ufl.domain.extract_unique_domain(mt.terminal)
tdim = domain.topological_dimension()
d_components = ufl.permutation.compute_indices((tdim, ) * num_ld)
elif num_gd:
domain = mt.terminal.ufl_domain()
domain = ufl.domain.extract_unique_domiain(mt.terminal)
gdim = domain.geometric_dimension()
d_components = ufl.permutation.compute_indices((gdim, ) * num_gd)
else:
Expand Down
8 changes: 4 additions & 4 deletions ffcx/ir/elementtables.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def generate_psi_table_name(quadrature_rule, element_counter, averaged: str, ent
def get_modified_terminal_element(mt) -> typing.Optional[ModifiedTerminalElement]:
gd = mt.global_derivatives
ld = mt.local_derivatives

domain = ufl.domain.extract_unique_domain(mt.terminal)
# Extract element from FormArguments and relevant GeometricQuantities
if isinstance(mt.terminal, ufl.classes.FormArgument):
if gd and mt.reference_value:
Expand All @@ -203,7 +203,7 @@ def get_modified_terminal_element(mt) -> typing.Optional[ModifiedTerminalElement
raise RuntimeError("Not expecting reference value of x.")
if gd:
raise RuntimeError("Not expecting global derivatives of x.")
element = convert_element(mt.terminal.ufl_domain().ufl_coordinate_element())
element = convert_element(domain.ufl_coordinate_element())
if not ld:
fc = mt.flat_component
else:
Expand All @@ -216,7 +216,7 @@ def get_modified_terminal_element(mt) -> typing.Optional[ModifiedTerminalElement
raise RuntimeError("Not expecting reference value of J.")
if gd:
raise RuntimeError("Not expecting global derivatives of J.")
element = convert_element(mt.terminal.ufl_domain().ufl_coordinate_element())
element = convert_element(domain.ufl_coordinate_element())
assert len(mt.component) == 2
# Translate component J[i,d] to x element context rgrad(x[i])[d]
fc, d = mt.component # x-component, derivative
Expand All @@ -226,7 +226,7 @@ def get_modified_terminal_element(mt) -> typing.Optional[ModifiedTerminalElement

assert (mt.averaged is None) or not (ld or gd)
# Change derivatives format for table lookup
gdim = mt.terminal.ufl_domain().geometric_dimension()
gdim = domain.geometric_dimension()
local_derivatives = ufl.utils.derivativetuples.derivative_listing_to_counts(
ld, gdim)

Expand Down
2 changes: 1 addition & 1 deletion ffcx/ir/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def _compute_expression_ir(expression, index, prefix, analysis, options, visuali
expression = expression[0]

try:
cell = expression.ufl_domain().ufl_cell()
cell = ufl.domain.extract_unique_domain(expression).ufl_cell()
except AttributeError:
# This case corresponds to a spatially constant expression
# without any dependencies
Expand Down