Skip to content
This repository has been archived by the owner on Feb 21, 2022. It is now read-only.

Commit

Permalink
Merged in miklos1/reference-cell (pull request #20)
Browse files Browse the repository at this point in the history
Smarter FIAT reference cells ("reference elements")
  • Loading branch information
miklos1 committed Jul 25, 2016
2 parents 1e26f31 + 8ae5f1b commit c5471a9
Show file tree
Hide file tree
Showing 7 changed files with 463 additions and 228 deletions.
52 changes: 24 additions & 28 deletions FIAT/finite_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
from __future__ import absolute_import

import numpy
from six.moves import map

from FIAT.polynomial_set import PolynomialSet
from FIAT.quadrature import make_quadrature
from FIAT.reference_element import TensorProductCell
from FIAT.quadrature_schemes import create_quadrature


class FiniteElement(object):
Expand Down Expand Up @@ -156,24 +157,34 @@ def get_num_members(self, arg):
return self.get_nodal_basis().get_expansion_set().get_num_members(arg)


def _facet_support_dofs(elem, quad, facet_transform, facets):
"""Generic facet support dofs constructor.
def entity_support_dofs(elem, entity_dim):
"""Return the map of entity id to the degrees of freedom for which the
corresponding basis functions take non-zero values
:arg elem: FIAT finite element
:arg quad: Quadrature rule on the facet
:arg facet_transform: A function mapping a facet number onto a function
which maps coordinates on the facet onto coordinates on the cell.
:arg facets: Facet numbers to loop over.
:arg entity_dim: Dimension of the cell subentity.
"""
eps = 1.e-8 # Is this a safe value?
if not hasattr(elem, "_entity_support_dofs"):
elem._entity_support_dofs = {}
cache = elem._entity_support_dofs
try:
return cache[entity_dim]
except KeyError:
pass

weights = quad.get_weights()
ref_el = elem.get_reference_element()
dim = ref_el.get_spatial_dimension()

entity_cell = elem.ref_el.construct_subelement(entity_dim)
quad = create_quadrature(entity_cell, max(2*elem.degree(), 1))
weights = quad.get_weights()

eps = 1.e-8 # Is this a safe value?

result = {}
for f in facets:
points = map(facet_transform(f), quad.get_points())
for f in elem.entity_dofs()[entity_dim].keys():
entity_transform = elem.ref_el.get_entity_transform(entity_dim, f)
points = list(map(entity_transform, quad.get_points()))

# Integrate the square of the basis functions on the facet.
vals = numpy.double(elem.tabulate(0, points)[(0,) * dim])
Expand All @@ -187,20 +198,5 @@ def _facet_support_dofs(elem, quad, facet_transform, facets):

result[f] = [dof for dof, i in enumerate(ints) if i > eps]

cache[entity_dim] = result
return result


def facet_support_dofs(elem):
"""Return the map of facet id to the degrees of freedom for which the
corresponding basis functions take non-zero values."""
if not hasattr(elem, "_facet_support_dofs"):
# Non-extruded cells only
assert not isinstance(elem.ref_el, TensorProductCell)

q = make_quadrature(elem.ref_el.get_facet_element(), max(2*elem.degree(), 1))
ft = lambda f: elem.ref_el.get_facet_transform(f)
dim = elem.ref_el.get_spatial_dimension()
facets = elem.entity_dofs()[dim-1].keys()
elem._facet_support_dofs = _facet_support_dofs(elem, q, ft, facets)

return elem._facet_support_dofs
28 changes: 15 additions & 13 deletions FIAT/quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from __future__ import absolute_import

import itertools
import math
import numpy

Expand All @@ -33,6 +34,9 @@ class QuadratureRule(object):
as the weighted sum of a function evaluated at a set of points."""

def __init__(self, ref_el, pts, wts):
if len(wts) != len(pts):
raise ValueError("Have %d weights, but %d points" % (len(wts), len(pts)))

self.ref_el = ref_el
self.pts = pts
self.wts = wts
Expand Down Expand Up @@ -220,28 +224,26 @@ def make_quadrature(ref_el, m):
msg = "Expecting at least one (not %d) quadrature point per direction" % min_m
assert (min_m > 0), msg

if ref_el.get_shape() == reference_element.LINE:
if ref_el.get_shape() == reference_element.POINT:
return QuadratureRule(ref_el, [()], [1])
elif ref_el.get_shape() == reference_element.LINE:
return GaussJacobiQuadratureLineRule(ref_el, m)
elif ref_el.get_shape() == reference_element.TRIANGLE:
return CollapsedQuadratureTriangleRule(ref_el, m)
elif ref_el.get_shape() == reference_element.TETRAHEDRON:
return CollapsedQuadratureTetrahedronRule(ref_el, m)
elif ref_el.get_shape() == reference_element.QUADRILATERAL:
quad_line = make_quadrature(reference_element.UFCInterval(), m)
return make_tensor_product_quadrature(quad_line, quad_line)
elif ref_el.get_shape() == reference_element.TENSORPRODUCT:
quadA = make_quadrature(ref_el.A, m[0])
quadB = make_quadrature(ref_el.B, m[1])
return make_tensor_product_quadrature(quadA, quadB)


def make_tensor_product_quadrature(quadA, quadB):
def make_tensor_product_quadrature(*quad_rules):
"""Returns the quadrature rule for a TensorProduct cell, by combining
the quadrature rules of the two components."""
ref_el = reference_element.TensorProductCell(quadA.ref_el, quadB.ref_el)
the quadrature rules of the components."""
ref_el = reference_element.TensorProductCell(*[q.ref_el
for q in quad_rules])
# Coordinates are "concatenated", weights are multiplied
pts = tuple([tuple(pt_a) + tuple(pt_b) for pt_a in quadA.pts for pt_b in quadB.pts])
wts = tuple([wt_a * wt_b for wt_a in quadA.wts for wt_b in quadB.wts])
pts = [list(itertools.chain(*pt_tuple))
for pt_tuple in itertools.product(*[q.pts for q in quad_rules])]
wts = [numpy.prod(wt_tuple)
for wt_tuple in itertools.product(*[q.wts for q in quad_rules])]
return QuadratureRule(ref_el, pts, wts)


Expand Down
17 changes: 13 additions & 4 deletions FIAT/quadrature_schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from numpy import array, arange, float64

# FIAT
from FIAT.reference_element import TENSORPRODUCT, UFCTriangle, UFCTetrahedron
from FIAT.reference_element import QUADRILATERAL, TENSORPRODUCT, UFCTriangle, UFCTetrahedron
from FIAT.quadrature import QuadratureRule, make_quadrature, make_tensor_product_quadrature


Expand All @@ -64,9 +64,18 @@ def create_quadrature(ref_el, degree, scheme="default"):
integrate exactly.
"""
if ref_el.get_shape() == TENSORPRODUCT:
quadA = create_quadrature(ref_el.A, degree[0], scheme)
quadB = create_quadrature(ref_el.B, degree[1], scheme)
return make_tensor_product_quadrature(quadA, quadB)
try:
degree = tuple(degree)
except TypeError:
degree = (degree,) * len(ref_el.cells)

assert len(ref_el.cells) == len(degree)
quad_rules = [create_quadrature(c, d, scheme)
for c, d in zip(ref_el.cells, degree)]
return make_tensor_product_quadrature(*quad_rules)

if ref_el.get_shape() == QUADRILATERAL:
return create_quadrature(ref_el.product, degree, scheme)

if degree < 0:
raise ValueError("Need positive degree, not %d" % degree)
Expand Down

0 comments on commit c5471a9

Please sign in to comment.