Skip to content

Commit

Permalink
Merged in michal/ufl-constant (pull request #116)
Browse files Browse the repository at this point in the history
Replace ufl.Constant

Approved-by: Chris Richardson <chris@bpi.cam.ac.uk>
Approved-by: Garth N. Wells <gnw20@cam.ac.uk>
  • Loading branch information
michalhabera authored and garth-wells committed Sep 2, 2019
2 parents addca65 + 5940656 commit 09cdffb
Show file tree
Hide file tree
Showing 14 changed files with 133 additions and 41 deletions.
5 changes: 2 additions & 3 deletions doc/sphinx/source/manual/form_language.rst
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,8 @@ There is a shorthand for this, whose use is similar to ``Arguments``, called

u, p = Coefficients(TH)

Spatially constant (or discontinuous piecewise constant) functions can
conveniently be represented by ``Constant``, ``VectorConstant``, and
``TensorConstant``::
Spatially constant values can conveniently be represented by
``Constant``, ``VectorConstant``, and ``TensorConstant``::

c0 = Constant(cell)
v0 = VectorConstant(cell)
Expand Down
3 changes: 0 additions & 3 deletions test/test_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,18 +603,15 @@ def Nw(x, derivatives):
assert derivatives == (0,)
return dw

w, b, K = form_data_f.original_form.coefficients()
mapping = {K: Kv, b: bv, w: Nw}
fv2 = f_expression((0,), mapping)
self.assertAlmostEqual(fv, fv2)

w, b, K = form_data_F.original_form.coefficients()
v, = form_data_F.original_form.arguments()
mapping = {K: Kv, b: bv, v: Nv, w: Nw}
Fv2 = F_expression((0,), mapping)
self.assertAlmostEqual(Fv, Fv2)

w, b, K = form_data_J.original_form.coefficients()
v, u = form_data_J.original_form.arguments()
mapping = {K: Kv, b: bv, v: Nv, u: Nu, w: Nw}
Jv2 = J_expression((0,), mapping)
Expand Down
5 changes: 4 additions & 1 deletion test/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ def df(v):


def testCoefficient():
v = Constant(triangle)
coord_elem = VectorElement("P", triangle, 1, dim=3)
mesh = Mesh(coord_elem)
V = FunctionSpace(mesh, FiniteElement("P", triangle, 1))
v = Coefficient(V)
assert round(expand_derivatives(diff(v, v))-1.0, 7) == 0


Expand Down
4 changes: 2 additions & 2 deletions ufl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@
Arguments, TestFunctions, TrialFunctions

# Coefficients
from ufl.coefficient import Coefficient, Coefficients, \
Constant, VectorConstant, TensorConstant
from ufl.coefficient import Coefficient, Coefficients
from ufl.constant import Constant, VectorConstant, TensorConstant

# Split function
from ufl.split_functions import split
Expand Down
6 changes: 6 additions & 0 deletions ufl/algorithms/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ufl.core.terminal import Terminal, FormArgument
from ufl.argument import Argument
from ufl.coefficient import Coefficient
from ufl.constant import Constant
from ufl.algorithms.traversal import iter_expressions
from ufl.corealg.traversal import unique_pre_traversal, traverse_unique_terminals

Expand Down Expand Up @@ -110,6 +111,11 @@ def extract_coefficients(a):
return sorted_by_count(extract_type(a, Coefficient))


def extract_constants(a):
"""Build a sorted list of all constants in a"""
return sorted_by_count(extract_type(a, Constant))


def extract_arguments_and_coefficients(a):
"""Build two sorted lists of all arguments and coefficients
in a, which can be a Form, Integral or Expr."""
Expand Down
3 changes: 3 additions & 0 deletions ufl/algorithms/apply_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def independent_operator(self, o):
# Literals are by definition independent of any differentiation variable
constant_value = independent_terminal

# Constants are independent of any differentiation
constant = independent_terminal

# Rules for form arguments must be specified in specialized rule set
form_argument = override

Expand Down
3 changes: 3 additions & 0 deletions ufl/algorithms/estimate_degrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def constant_value(self, v):
"Constant values are constant."
return 0

def constant(self, v):
return 0

def geometric_quantity(self, v):
"Some geometric quantities are cellwise constant. Others are nonpolynomial and thus hard to estimate."
if is_cellwise_constant(v):
Expand Down
3 changes: 2 additions & 1 deletion ufl/algorithms/formfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ufl.form import Form
from ufl.finiteelement import FiniteElementBase
from ufl.core.expr import Expr
from ufl.constant import Constant
from ufl.argument import Argument
from ufl.coefficient import Coefficient

Expand Down Expand Up @@ -148,7 +149,7 @@ def interpret_ufl_namespace(namespace):
# FIXME: Remove after FFC is updated to use reserved_objects:
ufd.object_names[name] = value
ufd.object_by_name[name] = value
elif isinstance(value, (FiniteElementBase, Coefficient, Argument, Form, Expr)):
elif isinstance(value, (FiniteElementBase, Coefficient, Constant, Argument, Form, Expr)):
# Store instance <-> name mappings for important objects
# without a reserved name
ufd.object_names[id(value)] = name
Expand Down
5 changes: 4 additions & 1 deletion ufl/algorithms/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ufl.classes import (Label,
Index, MultiIndex,
Coefficient, Argument,
GeometricQuantity, ConstantValue,
GeometricQuantity, ConstantValue, Constant,
ExprList, ExprMapping)
from ufl.log import error
from ufl.corealg.traversal import traverse_unique_terminals, pre_traversal
Expand Down Expand Up @@ -71,6 +71,9 @@ def compute_terminal_hashdata(expressions, renumbering):
elif isinstance(expr, Coefficient):
data = expr._ufl_signature_data_(renumbering)

elif isinstance(expr, Constant):
data = expr._ufl_signature_data_(renumbering)

elif isinstance(expr, Argument):
data = expr._ufl_signature_data_(renumbering)

Expand Down
32 changes: 2 additions & 30 deletions ufl/coefficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from ufl.log import error
from ufl.core.ufl_type import ufl_type
from ufl.core.terminal import FormArgument
from ufl.finiteelement import FiniteElementBase, FiniteElement, VectorElement, TensorElement
from ufl.domain import as_domain, default_domain
from ufl.finiteelement import FiniteElementBase
from ufl.domain import default_domain
from ufl.functionspace import AbstractFunctionSpace, FunctionSpace
from ufl.split_functions import split
from ufl.utils.counted import counted_init
Expand Down Expand Up @@ -117,34 +117,6 @@ def __eq__(self, other):
self._ufl_function_space == other._ufl_function_space)


# --- Helper functions for defining constant coefficients without
# --- specifying element ---

def Constant(domain, count=None):
"""UFL value: Represents a globally constant scalar valued coefficient."""
domain = as_domain(domain)
element = FiniteElement("Real", domain.ufl_cell(), 0)
fs = FunctionSpace(domain, element)
return Coefficient(fs, count=count)


def VectorConstant(domain, dim=None, count=None):
"""UFL value: Represents a globally constant vector valued coefficient."""
domain = as_domain(domain)
element = VectorElement("Real", domain.ufl_cell(), 0, dim)
fs = FunctionSpace(domain, element)
return Coefficient(fs, count=count)


def TensorConstant(domain, shape=None, symmetry=None, count=None):
"""UFL value: Represents a globally constant tensor valued coefficient."""
domain = as_domain(domain)
element = TensorElement("Real", domain.ufl_cell(), 0, shape=shape,
symmetry=symmetry)
fs = FunctionSpace(domain, element)
return Coefficient(fs, count=count)


# --- Helper functions for subfunctions on mixed elements ---

def Coefficients(function_space):
Expand Down
88 changes: 88 additions & 0 deletions ufl/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
"""This module defines classes representing non-literal values
which are constant with respect to a domain."""

# Copyright (C) 2019 Michal Habera
#
# This file is part of UFL.
#
# UFL is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# UFL is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with UFL. If not, see <http://www.gnu.org/licenses/>.

from ufl.core.ufl_type import ufl_type
from ufl.core.terminal import Terminal
from ufl.domain import as_domain
from ufl.utils.counted import counted_init


@ufl_type()
class Constant(Terminal):
_ufl_noslots_ = True
_globalcount = 0

def __init__(self, domain, shape=(), count=None):
Terminal.__init__(self)
counted_init(self, count=count, countedclass=Constant)

self._ufl_domain = as_domain(domain)
self._ufl_shape = shape

# Repr string is build in such way, that reconstruction
# with eval() is possible
self._repr = "Constant({}, {}, {})".format(
repr(self._ufl_domain), repr(self._ufl_shape), repr(self._count))

def count(self):
return self._count

@property
def ufl_shape(self):
return self._ufl_shape

def ufl_domain(self):
return self._ufl_domain

def ufl_domains(self):
return (self.ufl_domain(), )

def is_cellwise_constant(self):
return True

def __str__(self):
count = str(self._count)
if len(count) == 1:
return "c_%s" % count
else:
return "c_{%s}" % count

def __repr__(self):
return self._repr

def __eq__(self, other):
if not isinstance(other, Constant):
return False
if self is other:
return True
return (self._count == other._count and
self._ufl_domain == other._ufl_domain and
self._ufl_shape == self._ufl_shape)


def VectorConstant(domain, count=None):
domain = as_domain(domain)
return Constant(domain, shape=(domain.geometric_dimension(), ), count=count)


def TensorConstant(domain, count=None):
domain = as_domain(domain)
return Constant(domain, shape=(domain.geometric_dimension(), domain.geometric_dimension()), count=count)
1 change: 1 addition & 0 deletions ufl/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def find_geometric_dimension(expr):
cell = element.cell()
if cell is not None:
gdims.add(cell.geometric_dimension())

if len(gdims) != 1:
error("Cannot determine geometric dimension from expression.")
gdim, = gdims
Expand Down
7 changes: 7 additions & 0 deletions ufl/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class Form(object):
"_arguments",
"_coefficients",
"_coefficient_numbering",
"_constants",
"_hash",
"_signature",
# --- Dict that external frameworks can place framework-specific
Expand Down Expand Up @@ -122,6 +123,9 @@ def __init__(self, integrals):
self._coefficients = None
self._coefficient_numbering = None

from ufl.algorithms.analysis import extract_constants
self._constants = extract_constants(self)

# Internal variables for caching of hash and signature after
# first request
self._hash = None
Expand Down Expand Up @@ -235,6 +239,9 @@ def coefficient_numbering(self):
self._analyze_form_arguments()
return self._coefficient_numbering

def constants(self):
return self._constants

def signature(self):
"Signature for use with jit cache (independent of incidental numbering of indices etc.)"
if self._signature is None:
Expand Down
9 changes: 9 additions & 0 deletions ufl/formatting/ufl2unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,15 @@ def coefficient(self, o):
return "%s%s" % (var, superscript_number(i))
return self.coefficient_names[o.count()]

def constant(self, o):
i = o.count()
var = "c"
if len(o.ufl_shape) == 1:
var += UC.combining_right_arrow_above
elif len(o.ufl_shape) > 1 and self.colorama_bold:
var = "%s%s%s" % (colorama.Style.BRIGHT, var, colorama.Style.RESET_ALL)
return "%s%s" % (var, superscript_number(i))

def multi_index(self, o):
return ",".join(format_index(i) for i in o)

Expand Down

0 comments on commit 09cdffb

Please sign in to comment.