Skip to content
This repository has been archived by the owner on Feb 21, 2022. It is now read-only.

Commit

Permalink
Merged in miklos1/bernstein (pull request #51)
Browse files Browse the repository at this point in the history
Add Bernstein element

Approved-by: Lawrence Mitchell <wence@gmx.li>
Approved-by: rckirby <robert_kirby@baylor.edu>
  • Loading branch information
miklos1 committed Oct 30, 2018
2 parents 691d3fa + c4b7b1b commit f2a3059
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 0 deletions.
2 changes: 2 additions & 0 deletions FIAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# Import finite element classes
from FIAT.finite_element import FiniteElement, CiarletElement # noqa: F401
from FIAT.argyris import Argyris
from FIAT.bernstein import Bernstein
from FIAT.bell import Bell
from FIAT.argyris import QuinticArgyris
from FIAT.brezzi_douglas_marini import BrezziDouglasMarini
Expand Down Expand Up @@ -47,6 +48,7 @@
# List of supported elements and mapping to element classes
supported_elements = {"Argyris": Argyris,
"Bell": Bell,
"Bernstein": Bernstein,
"Brezzi-Douglas-Marini": BrezziDouglasMarini,
"Brezzi-Douglas-Fortin-Marini": BrezziDouglasFortinMarini,
"Bubble": Bubble,
Expand Down
199 changes: 199 additions & 0 deletions FIAT/bernstein.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2018 Miklós Homolya
#
# This file is part of FIAT.
#
# FIAT is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# FIAT is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
# License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with FIAT. If not, see <https://www.gnu.org/licenses/>.

import math
import numpy

from FIAT.finite_element import FiniteElement
from FIAT.dual_set import DualSet
from FIAT.polynomial_set import mis


class BernsteinDualSet(DualSet):
"""The dual basis for Bernstein elements."""

def __init__(self, ref_el, degree):
# Initialise data structures
topology = ref_el.get_topology()
entity_ids = {dim: {entity_i: []
for entity_i in entities}
for dim, entities in topology.items()}

# Calculate inverse topology
inverse_topology = {vertices: (dim, entity_i)
for dim, entities in topology.items()
for entity_i, vertices in entities.items()}

# Generate triangular barycentric indices
dim = ref_el.get_spatial_dimension()
kss = mis(dim + 1, degree)

# Fill data structures
nodes = []
for i, ks in enumerate(kss):
vertices, = numpy.nonzero(ks)
entity_dim, entity_i = inverse_topology[tuple(vertices)]
entity_ids[entity_dim][entity_i].append(i)

# Leave nodes unimplemented for now
nodes.append(None)

super(BernsteinDualSet, self).__init__(nodes, ref_el, entity_ids)


class Bernstein(FiniteElement):
"""A finite element with Bernstein polynomials as basis functions."""

def __init__(self, ref_el, degree):
dual = BernsteinDualSet(ref_el, degree)
k = 0 # 0-form
super(Bernstein, self).__init__(ref_el, dual, degree, k)

def degree(self):
"""The degree of the polynomial space."""
return self.get_order()

def value_shape(self):
"""The value shape of the finite element functions."""
return ()

def tabulate(self, order, points, entity=None):
"""Return tabulated values of derivatives up to given order of
basis functions at given points.
:arg order: The maximum order of derivative.
:arg points: An iterable of points.
:arg entity: Optional (dimension, entity number) pair
indicating which topological entity of the
reference element to tabulate on. If ``None``,
default cell-wise tabulation is performed.
"""
# Transform points to reference cell coordinates
ref_el = self.get_reference_element()
if entity is None:
entity = (ref_el.get_spatial_dimension(), 0)

entity_dim, entity_id = entity
entity_transform = ref_el.get_entity_transform(entity_dim, entity_id)
cell_points = list(map(entity_transform, points))

# Construct Cartesian to Barycentric coordinate mapping
vs = numpy.asarray(ref_el.get_vertices())
B2R = numpy.vstack([vs.T, numpy.ones(len(vs))])
R2B = numpy.linalg.inv(B2R)

B = numpy.hstack([cell_points,
numpy.ones((len(cell_points), 1))]).dot(R2B.T)

# Evaluate everything
deg = self.degree()
dim = ref_el.get_spatial_dimension()
raw_result = {(alpha, i): vec
for i, ks in enumerate(mis(dim + 1, deg))
for o in range(order + 1)
for alpha, vec in bernstein_Dx(B, ks, o, R2B).items()}

# Rearrange result
space_dim = self.space_dimension()
dtype = numpy.array(list(raw_result.values())).dtype
result = {alpha: numpy.zeros((space_dim, len(cell_points)), dtype=dtype)
for o in range(order + 1)
for alpha in mis(dim, o)}
for (alpha, i), vec in raw_result.items():
result[alpha][i, :] = vec
return result


def bernstein_db(points, ks, alpha=None):
"""Evaluates Bernstein polynomials or its derivative at barycentric
points.
:arg points: array of points in barycentric coordinates
:arg ks: exponents defining the Bernstein polynomial
:arg alpha: derivative tuple
:returns: array of Bernstein polynomial values at given points.
"""
points = numpy.asarray(points)
ks = numpy.array(tuple(ks))

N, d_1 = points.shape
assert d_1 == len(ks)

if alpha is None:
alpha = numpy.zeros(d_1)
else:
alpha = numpy.array(tuple(alpha))
assert d_1 == len(alpha)

ls = ks - alpha
if any(k < 0 for k in ls):
return numpy.zeros(len(points))
elif all(k == 0 for k in ls):
return numpy.ones(len(points))
else:
# Calculate coefficient
coeff = math.factorial(ks.sum())
for k in ls:
coeff //= math.factorial(k)
return coeff * numpy.prod(points**ls, axis=1)


def bernstein_Dx(points, ks, order, R2B):
"""Evaluates Bernstein polynomials or its derivatives according to
reference coordinates.
:arg points: array of points in BARYCENTRIC COORDINATES
:arg ks: exponents defining the Bernstein polynomial
:arg alpha: derivative order (returns all derivatives of this
specified order)
:arg R2B: linear mapping from reference to barycentric coordinates
:returns: dictionary mapping from derivative tuples to arrays of
Bernstein polynomial values at given points.
"""
points = numpy.asarray(points)
ks = tuple(ks)

N, d_1 = points.shape
assert d_1 == len(ks)

# Collect derivatives according to barycentric coordinates
Db_map = {alpha: bernstein_db(points, ks, alpha)
for alpha in mis(d_1, order)}

# Arrange derivative tensor (barycentric coordinates)
dtype = numpy.array(list(Db_map.values())).dtype
Db_shape = (d_1,) * order
Db_tensor = numpy.empty(Db_shape + (N,), dtype=dtype)
for ds in numpy.ndindex(Db_shape):
alpha = [0] * d_1
for d in ds:
alpha[d] += 1
Db_tensor[ds + (slice(None),)] = Db_map[tuple(alpha)]

# Coordinate transformation: barycentric -> reference
result = {}
for alpha in mis(d_1 - 1, order):
values = Db_tensor
for d, k in enumerate(alpha):
for _ in range(k):
values = R2B[:, d].dot(values)
result[alpha] = values
return result
85 changes: 85 additions & 0 deletions test/unit/test_bernstein.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2018 Miklós Homolya
#
# This file is part of FIAT.
#
# FIAT is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# FIAT is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
# License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with FIAT. If not, see <https://www.gnu.org/licenses/>.

import numpy
import pytest

from FIAT.reference_element import ufc_simplex
from FIAT.bernstein import Bernstein
from FIAT.quadrature_schemes import create_quadrature


D02 = numpy.array([
[0.65423405, 1.39160021, 0.65423405, 3.95416573, 1.39160021, 3.95416573],
[3.95416573, 3.95416573, 1.39160021, 1.39160021, 0.65423405, 0.65423405],
[0.0831321, -2.12896637, 2.64569763, -7.25409741, 1.17096531, -6.51673126],
[0., 0., 0., 0., 0., 0.],
[-7.90833147, -7.90833147, -2.78320042, -2.78320042, -1.30846811, -1.30846811],
[-2.12896637, 0.0831321, -7.25409741, 2.64569763, -6.51673126, 1.17096531],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[3.95416573, 3.95416573, 1.39160021, 1.39160021, 0.65423405, 0.65423405],
[1.39160021, 0.65423405, 3.95416573, 0.65423405, 3.95416573, 1.39160021],
])

D11 = numpy.array([
[0.65423405, 1.39160021, 0.65423405, 3.95416573, 1.39160021, 3.95416573],
[3.29993168, 2.56256552, 0.73736616, -2.56256552, -0.73736616, -3.29993168],
[0.73736616, -0.73736616, 3.29993168, -3.29993168, 2.56256552, -2.56256552],
[-3.95416573, -3.95416573, -1.39160021, -1.39160021, -0.65423405, -0.65423405],
[-4.69153189, -3.21679958, -4.69153189, 1.90833147, -3.21679958, 1.90833147],
[-1.39160021, -0.65423405, -3.95416573, -0.65423405, -3.95416573, -1.39160021],
[0., 0., 0., 0., 0., 0.],
[3.95416573, 3.95416573, 1.39160021, 1.39160021, 0.65423405, 0.65423405],
[1.39160021, 0.65423405, 3.95416573, 0.65423405, 3.95416573, 1.39160021],
[0., 0., 0., 0., 0., 0.],
])

D20 = numpy.array([
[0.65423405, 1.39160021, 0.65423405, 3.95416573, 1.39160021, 3.95416573],
[2.64569763, 1.17096531, 0.0831321, -6.51673126, -2.12896637, -7.25409741],
[1.39160021, 0.65423405, 3.95416573, 0.65423405, 3.95416573, 1.39160021],
[-7.25409741, -6.51673126, -2.12896637, 1.17096531, 0.0831321, 2.64569763],
[-2.78320042, -1.30846811, -7.90833147, -1.30846811, -7.90833147, -2.78320042],
[0., 0., 0., 0., 0., 0.],
[3.95416573, 3.95416573, 1.39160021, 1.39160021, 0.65423405, 0.65423405],
[1.39160021, 0.65423405, 3.95416573, 0.65423405, 3.95416573, 1.39160021],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
])


def test_bernstein_2nd_derivatives():
ref_el = ufc_simplex(2)
degree = 3

elem = Bernstein(ref_el, degree)
rule = create_quadrature(ref_el, degree)
points = rule.get_points()

actual = elem.tabulate(2, points)

assert numpy.allclose(D02, actual[(0, 2)])
assert numpy.allclose(D11, actual[(1, 1)])
assert numpy.allclose(D20, actual[(2, 0)])


if __name__ == '__main__':
import os
pytest.main(os.path.abspath(__file__))

0 comments on commit f2a3059

Please sign in to comment.