From 7266bc14c80094391bd627909c792f886ad139ae Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Fri, 25 Aug 2023 11:51:36 +0100 Subject: [PATCH 01/13] Move Cnodes to Lnodes --- ffcx/codegeneration/{C/cnodes.py => lnodes.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename ffcx/codegeneration/{C/cnodes.py => lnodes.py} (100%) diff --git a/ffcx/codegeneration/C/cnodes.py b/ffcx/codegeneration/lnodes.py similarity index 100% rename from ffcx/codegeneration/C/cnodes.py rename to ffcx/codegeneration/lnodes.py From 8d3bfc6a6b9d8dbb8f6e45fba5f41e126f424c9d Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Fri, 25 Aug 2023 11:52:49 +0100 Subject: [PATCH 02/13] Use new LNodes --- ffcx/codegeneration/lnodes.py | 1319 ++++++++------------------------- 1 file changed, 326 insertions(+), 993 deletions(-) diff --git a/ffcx/codegeneration/lnodes.py b/ffcx/codegeneration/lnodes.py index c648d3284..8aba3029e 100644 --- a/ffcx/codegeneration/lnodes.py +++ b/ffcx/codegeneration/lnodes.py @@ -1,78 +1,122 @@ -# Copyright (C) 2013-2017 Martin Sandve Alnæs +# Copyright (C) 2013-2023 Martin Sandve Alnæs, Chris Richardson # # This file is part of FFCx.(https://www.fenicsproject.org) # # SPDX-License-Identifier: LGPL-3.0-or-later -import logging import numbers - +import ufl import numpy as np - -from ffcx.codegeneration.C.format_lines import Indented, format_indented_lines -from ffcx.codegeneration.C.format_value import format_float, format_int, format_value -from ffcx.codegeneration.C.precedence import PRECEDENCE - -logger = logging.getLogger("ffcx") -"""CNode TODO: -- Array copy statement -- Extend ArrayDecl and ArrayAccess with support for - flattened but conceptually multidimensional arrays, - maybe even with padding (FlattenedArray possibly covers what we need) -- Function declaration -- TypeDef -- Type -- TemplateArgumentList -- Class declaration -- Class definition +from enum import Enum + + +class PRECEDENCE: + """An enum-like class for operator precedence levels.""" + + HIGHEST = 0 + LITERAL = 0 + SYMBOL = 0 + SUBSCRIPT = 2 + + NOT = 3 + NEG = 3 + + MUL = 4 + DIV = 4 + + ADD = 5 + SUB = 5 + + LT = 7 + LE = 7 + GT = 7 + GE = 7 + EQ = 8 + NE = 8 + AND = 11 + OR = 12 + CONDITIONAL = 13 + ASSIGN = 13 + LOWEST = 15 + + +"""LNodes is intended as a minimal generic language description. +Formatting is done later, depending on the target language. + +Supported: + Floating point (and complex) and integer variables and multidimensional arrays + Range loops + Simple arithmetic, +-*/ + Math operations + Logic conditions + Comments +Not supported: + Pointers + Function Calls + Flow control (if, switch, while) + Booleans + Strings """ -# Some helper functions - -def is_zero_cexpr(cexpr): - return (isinstance(cexpr, LiteralFloat) and cexpr.value == 0.0) or ( - isinstance(cexpr, LiteralInt) and cexpr.value == 0 +def is_zero_lexpr(lexpr): + return (isinstance(lexpr, LiteralFloat) and lexpr.value == 0.0) or ( + isinstance(lexpr, LiteralInt) and lexpr.value == 0 ) -def is_one_cexpr(cexpr): - return (isinstance(cexpr, LiteralFloat) and cexpr.value == 1.0) or ( - isinstance(cexpr, LiteralInt) and cexpr.value == 1 +def is_one_lexpr(lexpr): + return (isinstance(lexpr, LiteralFloat) and lexpr.value == 1.0) or ( + isinstance(lexpr, LiteralInt) and lexpr.value == 1 ) -def is_negative_one_cexpr(cexpr): - return (isinstance(cexpr, LiteralFloat) and cexpr.value == -1.0) or ( - isinstance(cexpr, LiteralInt) and cexpr.value == -1 +def is_negative_one_lexpr(lexpr): + return (isinstance(lexpr, LiteralFloat) and lexpr.value == -1.0) or ( + isinstance(lexpr, LiteralInt) and lexpr.value == -1 ) def float_product(factors): """Build product of float factors, simplifying ones and zeros and returning 1.0 if empty sequence.""" - factors = [f for f in factors if not is_one_cexpr(f)] + factors = [f for f in factors if not is_one_lexpr(f)] if len(factors) == 0: return LiteralFloat(1.0) elif len(factors) == 1: return factors[0] else: for f in factors: - if is_zero_cexpr(f): + if is_zero_lexpr(f): return f return Product(factors) -# CNode core +class DataType(Enum): + """Representation of data types for variables in LNodes. + These can be REAL (same type as geometry), + SCALAR (same type as tensor), or INT (for entity indices etc.) + """ -class CNode(object): - """Base class for all C AST nodes.""" + REAL = 0 + SCALAR = 1 + INT = 2 - __slots__ = () - def __str__(self): - name = self.__class__.__name__ - raise NotImplementedError("Missing implementation of __str__ in " + name) +def merge_dtypes(dtype0, dtype1): + # Promote dtype to SCALAR or REAL if either argument matches + if DataType.SCALAR in (dtype0, dtype1): + return DataType.SCALAR + elif DataType.REAL in (dtype0, dtype1): + return DataType.REAL + elif (dtype0 == DataType.INT and dtype1 == DataType.INT): + return DataType.INT + else: + raise ValueError(f"Can't get dtype for binary operation with {dtype0, dtype1}") + + +class LNode(object): + """Base class for all AST nodes.""" def __eq__(self, other): name = self.__class__.__name__ @@ -82,29 +126,12 @@ def __ne__(self, other): return not self.__eq__(other) -# CExpr base classes - - -class CExpr(CNode): - """Base class for all C expressions. +class LExpr(LNode): + """Base class for all expressions. All subtypes should define a 'precedence' class attribute. - """ - __slots__ = () - - def ce_format(self, precision=None): - raise NotImplementedError("Missing implementation of ce_format() in CExpr.") - - def __str__(self): - try: - s = self.ce_format() - except Exception: - raise - - return s - def __getitem__(self, indices): return ArrayAccess(self, indices) @@ -116,346 +143,187 @@ def __neg__(self): return Neg(self) def __add__(self, other): - other = as_cexpr(other) - if is_zero_cexpr(self): + other = as_lexpr(other) + if is_zero_lexpr(self): return other - if is_zero_cexpr(other): + if is_zero_lexpr(other): return self if isinstance(other, Neg): return Sub(self, other.arg) return Add(self, other) def __radd__(self, other): - other = as_cexpr(other) - if is_zero_cexpr(self): + other = as_lexpr(other) + if is_zero_lexpr(self): return other - if is_zero_cexpr(other): + if is_zero_lexpr(other): return self if isinstance(self, Neg): return Sub(other, self.arg) return Add(other, self) def __sub__(self, other): - other = as_cexpr(other) - if is_zero_cexpr(self): + other = as_lexpr(other) + if is_zero_lexpr(self): return -other - if is_zero_cexpr(other): + if is_zero_lexpr(other): return self if isinstance(other, Neg): return Add(self, other.arg) + if isinstance(self, LiteralInt) and isinstance(other, LiteralInt): + return LiteralInt(self.value - other.value) return Sub(self, other) def __rsub__(self, other): - other = as_cexpr(other) - if is_zero_cexpr(self): + other = as_lexpr(other) + if is_zero_lexpr(self): return other - if is_zero_cexpr(other): + if is_zero_lexpr(other): return -self if isinstance(self, Neg): return Add(other, self.arg) return Sub(other, self) def __mul__(self, other): - other = as_cexpr(other) - if is_zero_cexpr(self): + other = as_lexpr(other) + if is_zero_lexpr(self): return self - if is_zero_cexpr(other): + if is_zero_lexpr(other): return other - if is_one_cexpr(self): + if is_one_lexpr(self): return other - if is_one_cexpr(other): + if is_one_lexpr(other): return self - if is_negative_one_cexpr(other): + if is_negative_one_lexpr(other): return Neg(self) - if is_negative_one_cexpr(self): + if is_negative_one_lexpr(self): return Neg(other) + if isinstance(self, LiteralInt) and isinstance(other, LiteralInt): + return LiteralInt(self.value * other.value) return Mul(self, other) def __rmul__(self, other): - other = as_cexpr(other) - if is_zero_cexpr(self): + other = as_lexpr(other) + if is_zero_lexpr(self): return self - if is_zero_cexpr(other): + if is_zero_lexpr(other): return other - if is_one_cexpr(self): + if is_one_lexpr(self): return other - if is_one_cexpr(other): + if is_one_lexpr(other): return self - if is_negative_one_cexpr(other): + if is_negative_one_lexpr(other): return Neg(self) - if is_negative_one_cexpr(self): + if is_negative_one_lexpr(self): return Neg(other) return Mul(other, self) def __div__(self, other): - other = as_cexpr(other) - if is_zero_cexpr(other): + other = as_lexpr(other) + if is_zero_lexpr(other): raise ValueError("Division by zero!") - if is_zero_cexpr(self): + if is_zero_lexpr(self): return self return Div(self, other) def __rdiv__(self, other): - other = as_cexpr(other) - if is_zero_cexpr(self): + other = as_lexpr(other) + if is_zero_lexpr(self): raise ValueError("Division by zero!") - if is_zero_cexpr(other): + if is_zero_lexpr(other): return other return Div(other, self) - # TODO: Error check types? Can't do that exactly as symbols here have no type. + # TODO: Error check types? __truediv__ = __div__ __rtruediv__ = __rdiv__ __floordiv__ = __div__ __rfloordiv__ = __rdiv__ - def __mod__(self, other): - other = as_cexpr(other) - if is_zero_cexpr(other): - raise ValueError("Division by zero!") - if is_zero_cexpr(self): - return self - return Mod(self, other) - - def __rmod__(self, other): - other = as_cexpr(other) - if is_zero_cexpr(self): - raise ValueError("Division by zero!") - if is_zero_cexpr(other): - return other - return Mod(other, self) - -class CExprOperator(CExpr): - """Base class for all C expression operator.""" +class LExprOperator(LExpr): + """Base class for all expression operators.""" - __slots__ = () sideeffect = False -class CExprTerminal(CExpr): - """Base class for all C expression terminals.""" +class LExprTerminal(LExpr): + """Base class for all expression terminals.""" - __slots__ = () sideeffect = False -# CExprTerminal types - - -class CExprLiteral(CExprTerminal): - """A float or int literal value.""" - - __slots__ = () - precedence = PRECEDENCE.LITERAL - - -class Null(CExprLiteral): - """A null pointer literal.""" - - __slots__ = () - precedence = PRECEDENCE.LITERAL +# LExprTerminal types - def ce_format(self, precision=None): - return "NULL" - def __eq__(self, other): - return isinstance(other, Null) - - -class LiteralFloat(CExprLiteral): +class LiteralFloat(LExprTerminal): """A floating point literal value.""" - __slots__ = ("value",) precedence = PRECEDENCE.LITERAL def __init__(self, value): - assert isinstance(value, (float, complex, int, np.number)) + assert isinstance(value, (float, complex)) self.value = value - - def ce_format(self, precision=None): - return format_float(self.value, precision) + if isinstance(value, complex): + self.dtype = DataType.SCALAR + else: + self.dtype = DataType.REAL def __eq__(self, other): return isinstance(other, LiteralFloat) and self.value == other.value - def __bool__(self): - return bool(self.value) - - __nonzero__ = __bool__ - def __float__(self): return float(self.value) - def flops(self): - return 0 - -class LiteralInt(CExprLiteral): +class LiteralInt(LExprTerminal): """An integer literal value.""" - __slots__ = ("value",) precedence = PRECEDENCE.LITERAL def __init__(self, value): assert isinstance(value, (int, np.number)) self.value = value - - def ce_format(self, precision=None): - return str(self.value) - - def flops(self): - return 0 + self.dtype = DataType.INT def __eq__(self, other): return isinstance(other, LiteralInt) and self.value == other.value - def __bool__(self): - return bool(self.value) - - __nonzero__ = __bool__ - - def __int__(self): - return int(self.value) - - def __float__(self): - return float(self.value) - def __hash__(self): - return hash(self.ce_format()) - - -class LiteralBool(CExprLiteral): - """A boolean literal value.""" - - __slots__ = ("value",) - precedence = PRECEDENCE.LITERAL - - def __init__(self, value): - assert isinstance(value, (bool,)) - self.value = value - - def ce_format(self, precision=None): - return "true" if self.value else "false" - - def __eq__(self, other): - return isinstance(other, LiteralBool) and self.value == other.value - - def __bool__(self): - return bool(self.value) - - __nonzero__ = __bool__ + return hash(self.value) -class LiteralString(CExprLiteral): - """A boolean literal value.""" - - __slots__ = ("value",) - precedence = PRECEDENCE.LITERAL - - def __init__(self, value): - assert isinstance(value, (str,)) - assert '"' not in value - self.value = value - - def ce_format(self, precision=None): - return '"%s"' % (self.value,) - - def __eq__(self, other): - return isinstance(other, LiteralString) and self.value == other.value - - -class Symbol(CExprTerminal): +class Symbol(LExprTerminal): """A named symbol.""" - __slots__ = ("name",) precedence = PRECEDENCE.SYMBOL - def __init__(self, name): + def __init__(self, name, dtype=None): assert isinstance(name, str) self.name = name - - def ce_format(self, precision=None): - return self.name - - def flops(self): - return 0 + self.dtype = dtype def __eq__(self, other): return isinstance(other, Symbol) and self.name == other.name def __hash__(self): - return hash(self.ce_format()) - + return hash(self.name) -# CExprOperator base classes - -class UnaryOp(CExprOperator): +class PrefixUnaryOp(LExprOperator): """Base class for unary operators.""" - __slots__ = ("arg",) - def __init__(self, arg): - self.arg = as_cexpr(arg) + self.arg = as_lexpr(arg) def __eq__(self, other): return isinstance(other, type(self)) and self.arg == other.arg - def flops(self): - raise NotImplementedError() - - -class PrefixUnaryOp(UnaryOp): - """Base class for prefix unary operators.""" - - __slots__ = () - - def ce_format(self, precision=None): - arg = self.arg.ce_format(precision) - if self.arg.precedence >= self.precedence: - arg = "(" + arg + ")" - return self.op + arg - - def __eq__(self, other): - return isinstance(other, type(self)) - - -class PostfixUnaryOp(UnaryOp): - """Base class for postfix unary operators.""" - - __slots__ = () - - def ce_format(self, precision=None): - arg = self.arg.ce_format(precision) - if self.arg.precedence >= self.precedence: - arg = "(" + arg + ")" - return arg + self.op - - def __eq__(self, other): - return isinstance(other, type(self)) - - -class BinOp(CExprOperator): - __slots__ = ("lhs", "rhs") +class BinOp(LExprOperator): def __init__(self, lhs, rhs): - self.lhs = as_cexpr(lhs) - self.rhs = as_cexpr(rhs) - - def ce_format(self, precision=None): - # Format children - lhs = self.lhs.ce_format(precision) - rhs = self.rhs.ce_format(precision) - - # Apply parentheses - if self.lhs.precedence >= self.precedence: - lhs = "(" + lhs + ")" - if self.rhs.precedence >= self.precedence: - rhs = "(" + rhs + ")" - - # Return combined string - return lhs + (" " + self.op + " ") + rhs + self.lhs = as_lexpr(lhs) + self.rhs = as_lexpr(rhs) def __eq__(self, other): return ( @@ -465,35 +333,21 @@ def __eq__(self, other): ) def __hash__(self): - return hash(self.ce_format()) + return hash(self.lhs) + hash(self.rhs) - def flops(self): - return 1 + self.lhs.flops() + self.rhs.flops() +class ArithmeticBinOp(BinOp): + def __init__(self, lhs, rhs): + self.lhs = as_lexpr(lhs) + self.rhs = as_lexpr(rhs) + self.dtype = merge_dtypes(self.lhs.dtype, self.rhs.dtype) -class NaryOp(CExprOperator): - """Base class for special n-ary operators.""" - __slots__ = ("args",) +class NaryOp(LExprOperator): + """Base class for special n-ary operators.""" def __init__(self, args): - self.args = [as_cexpr(arg) for arg in args] - - def ce_format(self, precision=None): - # Format children - args = [arg.ce_format(precision) for arg in self.args] - - # Apply parentheses - for i in range(len(args)): - if self.args[i].precedence >= self.precedence: - args[i] = "(" + args[i] + ")" - - # Return combined string - op = " " + self.op + " " - s = args[0] - for i in range(1, len(args)): - s += op + args[i] - return s + self.args = [as_lexpr(arg) for arg in args] def __eq__(self, other): return ( @@ -502,175 +356,81 @@ def __eq__(self, other): and all(a == b for a, b in zip(self.args, other.args)) ) - def flops(self): - flops = len(self.args) - 1 - for arg in self.args: - flops += arg.flops() - return flops - - -# CExpr unary operators - - -class AddressOf(PrefixUnaryOp): - __slots__ = () - precedence = PRECEDENCE.ADDRESSOF - op = "&" - - -class SizeOf(PrefixUnaryOp): - __slots__ = () - precedence = PRECEDENCE.SIZEOF - op = "sizeof" - class Neg(PrefixUnaryOp): - __slots__ = () precedence = PRECEDENCE.NEG op = "-" - -class Pos(PrefixUnaryOp): - __slots__ = () - precedence = PRECEDENCE.POS - op = "+" + def __init__(self, arg): + self.arg = as_lexpr(arg) + self.dtype = self.arg.dtype class Not(PrefixUnaryOp): - __slots__ = () precedence = PRECEDENCE.NOT op = "!" -class BitNot(PrefixUnaryOp): - __slots__ = () - precedence = PRECEDENCE.BIT_NOT - op = "~" - - -class PreIncrement(PrefixUnaryOp): - __slots__ = () - precedence = PRECEDENCE.PRE_INC - sideeffect = True - op = "++" - - -class PreDecrement(PrefixUnaryOp): - __slots__ = () - precedence = PRECEDENCE.PRE_DEC - sideeffect = True - op = "--" - - -class PostIncrement(PostfixUnaryOp): - __slots__ = () - precedence = PRECEDENCE.POST_INC - sideeffect = True - op = "++" +# Binary operators +# Arithmetic operators preserve the dtype of their operands +# The other operations (logical) do not need a dtype - -class PostDecrement(PostfixUnaryOp): - __slots__ = () - precedence = PRECEDENCE.POST_DEC - sideeffect = True - op = "--" - - -# CExpr binary operators - - -class Add(BinOp): - __slots__ = () +class Add(ArithmeticBinOp): precedence = PRECEDENCE.ADD op = "+" -class Sub(BinOp): - __slots__ = () +class Sub(ArithmeticBinOp): precedence = PRECEDENCE.SUB op = "-" -class Mul(BinOp): - __slots__ = () +class Mul(ArithmeticBinOp): precedence = PRECEDENCE.MUL op = "*" -class Div(BinOp): - __slots__ = () +class Div(ArithmeticBinOp): precedence = PRECEDENCE.DIV op = "/" -class Mod(BinOp): - __slots__ = () - precedence = PRECEDENCE.MOD - op = "%" - - class EQ(BinOp): - __slots__ = () precedence = PRECEDENCE.EQ op = "==" class NE(BinOp): - __slots__ = () precedence = PRECEDENCE.NE op = "!=" class LT(BinOp): - __slots__ = () precedence = PRECEDENCE.LT op = "<" class GT(BinOp): - __slots__ = () precedence = PRECEDENCE.GT op = ">" class LE(BinOp): - __slots__ = () precedence = PRECEDENCE.LE op = "<=" class GE(BinOp): - __slots__ = () precedence = PRECEDENCE.GE op = ">=" -class BitwiseAnd(BinOp): - __slots__ = () - precedence = PRECEDENCE.BITAND - op = "&" - - -class BitShiftR(BinOp): - __slots__ = () - precedence = PRECEDENCE.BITSHIFT - op = ">>" - - -class BitShiftL(BinOp): - __slots__ = () - precedence = PRECEDENCE.BITSHIFT - op = "<<" - - class And(BinOp): - __slots__ = () precedence = PRECEDENCE.AND op = "&&" class Or(BinOp): - __slots__ = () precedence = PRECEDENCE.OR op = "||" @@ -678,7 +438,6 @@ class Or(BinOp): class Sum(NaryOp): """Sum of any number of operands.""" - __slots__ = () precedence = PRECEDENCE.ADD op = "+" @@ -686,94 +445,86 @@ class Sum(NaryOp): class Product(NaryOp): """Product of any number of operands.""" - __slots__ = () precedence = PRECEDENCE.MUL op = "*" +class MathFunction(LExprOperator): + """A Math Function, with any arguments.""" + + precedence = PRECEDENCE.HIGHEST + + def __init__(self, func, args): + self.function = func + self.args = [as_lexpr(arg) for arg in args] + self.dtype = self.args[0].dtype + + def __eq__(self, other): + return ( + isinstance(other, type(self)) + and self.function == other.function + and len(self.args) == len(other.args) + and all(a == b for a, b in zip(self.args, other.args)) + ) + + class AssignOp(BinOp): """Base class for assignment operators.""" - __slots__ = () precedence = PRECEDENCE.ASSIGN sideeffect = True def __init__(self, lhs, rhs): - BinOp.__init__(self, as_cexpr_or_string_symbol(lhs), rhs) + assert isinstance(lhs, LNode) + BinOp.__init__(self, lhs, rhs) class Assign(AssignOp): - __slots__ = () op = "=" - def flops(self): - return super().flops() - 1 - class AssignAdd(AssignOp): - __slots__ = () op = "+=" class AssignSub(AssignOp): - __slots__ = () op = "-=" class AssignMul(AssignOp): - __slots__ = () op = "*=" class AssignDiv(AssignOp): - __slots__ = () op = "/=" -# CExpr operators - - class FlattenedArray(object): """Syntax carrying object only, will get translated on __getitem__ to ArrayAccess.""" - __slots__ = ("array", "strides", "offset", "dims") - - def __init__(self, array, dummy=None, dims=None, strides=None, offset=None): - assert dummy is None, "Please use keyword arguments for strides or dims." - - # Typecheck array argument - if isinstance(array, ArrayDecl): - self.array = array.symbol - elif isinstance(array, Symbol): - self.array = array - else: - assert isinstance(array, str) - self.array = Symbol(array) + def __init__(self, array, dims=None): + assert dims is not None + assert isinstance(array, Symbol) + self.array = array # Allow expressions or literals as strides or dims and offset - if strides is None: - assert dims is not None, "Please provide either strides or dims." - assert isinstance(dims, (list, tuple)) - dims = tuple(as_cexpr(i) for i in dims) - self.dims = dims - n = len(dims) - literal_one = LiteralInt(1) - strides = [literal_one] * n - for i in range(n - 2, -1, -1): - s = strides[i + 1] - d = dims[i + 1] - if d == literal_one: - strides[i] = s - elif s == literal_one: - strides[i] = d - else: - strides[i] = d * s - else: - self.dims = None - assert isinstance(strides, (list, tuple)) - strides = tuple(as_cexpr(i) for i in strides) + assert isinstance(dims, (list, tuple)) + dims = tuple(as_lexpr(i) for i in dims) + self.dims = dims + n = len(dims) + literal_one = LiteralInt(1) + strides = [literal_one] * n + for i in range(n - 2, -1, -1): + s = strides[i + 1] + d = dims[i + 1] + if d == literal_one: + strides[i] = s + elif s == literal_one: + strides[i] = d + else: + strides[i] = d * s + self.strides = strides - self.offset = None if offset is None else as_cexpr(offset) def __getitem__(self, indices): if not isinstance(indices, (list, tuple)): @@ -788,10 +539,8 @@ def __getitem__(self, indices): i, s = (indices[0], self.strides[0]) literal_one = LiteralInt(1) flat = i if s == literal_one else s * i - if self.offset is not None: - flat = self.offset + flat for i, s in zip(indices[1:n], self.strides[1:n]): - flat = flat + (i if s == literal_one else s * i) + flat = flat + s * i # Delay applying ArrayAccess until we have all indices if n == len(self.strides): return ArrayAccess(self.array, flat) @@ -799,23 +548,24 @@ def __getitem__(self, indices): return FlattenedArray(self.array, strides=self.strides[n:], offset=flat) -class ArrayAccess(CExprOperator): - __slots__ = ("array", "indices") +class ArrayAccess(LExprOperator): precedence = PRECEDENCE.SUBSCRIPT def __init__(self, array, indices): # Typecheck array argument if isinstance(array, Symbol): self.array = array + self.dtype = array.dtype elif isinstance(array, ArrayDecl): self.array = array.symbol + self.dtype = array.symbol.dtype else: raise ValueError("Unexpected array type %s." % (type(array).__name__,)) # Allow expressions or literals as indices if not isinstance(indices, (list, tuple)): indices = (indices,) - self.indices = tuple(as_cexpr_or_string_symbol(i) for i in indices) + self.indices = tuple(as_lexpr(i) for i in indices) # Early error checking for negative array dimensions if any(isinstance(i, int) and i < 0 for i in self.indices): @@ -840,12 +590,6 @@ def __getitem__(self, indices): indices = (indices,) return ArrayAccess(self.array, self.indices + indices) - def ce_format(self, precision=None): - s = self.array.ce_format(precision) - for index in self.indices: - s += "[" + index.ce_format(precision) + "]" - return s - def __eq__(self, other): return ( isinstance(other, type(self)) @@ -854,37 +598,17 @@ def __eq__(self, other): ) def __hash__(self): - return hash(self.ce_format()) + return hash(self.array) - def flops(self): - return 0 - -class Conditional(CExprOperator): - __slots__ = ("condition", "true", "false") +class Conditional(LExprOperator): precedence = PRECEDENCE.CONDITIONAL def __init__(self, condition, true, false): - self.condition = as_cexpr(condition) - self.true = as_cexpr(true) - self.false = as_cexpr(false) - - def ce_format(self, precision=None): - # Format children - c = self.condition.ce_format(precision) - t = self.true.ce_format(precision) - f = self.false.ce_format(precision) - - # Apply parentheses - if self.condition.precedence >= self.precedence: - c = "(" + c + ")" - if self.true.precedence >= self.precedence: - t = "(" + t + ")" - if self.false.precedence >= self.precedence: - f = "(" + f + ")" - - # Return combined string - return c + " ? " + t + " : " + f + self.condition = as_lexpr(condition) + self.true = as_lexpr(true) + self.false = as_lexpr(false) + self.dtype = merge_dtypes(self.true.dtype, self.false.dtype) def __eq__(self, other): return ( @@ -894,278 +618,65 @@ def __eq__(self, other): and self.false == other.false ) - def flops(self): - raise NotImplementedError("Flop count is not implemented for conditionals") - - -class Call(CExprOperator): - __slots__ = ("function", "arguments") - precedence = PRECEDENCE.CALL - sideeffect = True - - def __init__(self, function, arguments=None): - self.function = as_cexpr_or_string_symbol(function) - - # Accept None, single, or multiple arguments; literals or CExprs - if arguments is None: - arguments = () - elif not isinstance(arguments, (tuple, list)): - arguments = (arguments,) - self.arguments = [as_cexpr(arg) for arg in arguments] - - def ce_format(self, precision=None): - args = ", ".join(arg.ce_format(precision) for arg in self.arguments) - return self.function.ce_format(precision) + "(" + args + ")" - - def __eq__(self, other): - return ( - isinstance(other, type(self)) - and self.function == other.function - and self.arguments == other.arguments - ) - - def flops(self): - return 1 +def as_lexpr(node): + """Typechecks and wraps an object as a valid LExpr. -def Sqrt(x): - return Call("sqrt", x) - - -# Conversion function to expression nodes - - -def _is_zero_valued(values): - if isinstance(values, (numbers.Integral, LiteralInt)): - return int(values) == 0 - elif isinstance(values, (numbers.Number, LiteralFloat)): - return float(values) == 0.0 - else: - return np.count_nonzero(values) == 0 - - -def as_cexpr(node): - """Typechecks and wraps an object as a valid CExpr. - - Accepts CExpr nodes, treats int and float as literals, and treats a - string as a symbol. + Accepts LExpr nodes, treats int and float as literals. """ - if isinstance(node, CExpr): + if isinstance(node, LExpr): return node - elif isinstance(node, bool): - return LiteralBool(node) elif isinstance(node, numbers.Integral): return LiteralInt(node) elif isinstance(node, numbers.Real): return LiteralFloat(node) - elif isinstance(node, str): - raise RuntimeError("Got string for CExpr, this is ambiguous: %s" % (node,)) - else: - raise RuntimeError("Unexpected CExpr type %s:\n%s" % (type(node), str(node))) - - -def as_cexpr_or_string_symbol(node): - if isinstance(node, str): - return Symbol(node) - return as_cexpr(node) - - -def as_cexpr_or_literal(node): - if isinstance(node, str): - return LiteralString(node) - return as_cexpr(node) - - -def as_symbol(symbol): - if isinstance(symbol, str): - symbol = Symbol(symbol) - assert isinstance(symbol, Symbol) - return symbol - - -def flattened_indices(indices, shape): - """Return a flattened indexing expression. - - Given a tuple of indices and a shape tuple, return - a CNode expression for flattened indexing into multidimensional - array. - - Indices and shape entries can be int values, str symbol names, or - CNode expressions. - - """ - n = len(shape) - if n == 0: - # Scalar - return as_cexpr(0) - elif n == 1: - # Simple vector - return as_cexpr(indices[0]) else: - # 2d or higher - strides = [None] * (n - 2) + [shape[-1], 1] - for i in range(n - 3, -1, -1): - strides[i] = Mul(shape[i + 1], strides[i + 1]) - result = indices[-1] - for i in range(n - 2, -1, -1): - result = Add(Mul(strides[i], indices[i]), result) - return result - - -# Base class for all statements - - -class CStatement(CNode): - """Base class for all C statements. - - Subtypes do _not_ define a 'precedence' class attribute. - - """ - - __slots__ = () - - # True if statement contains its own scope, false by default to be - # on the safe side - is_scoped = False - - def cs_format(self, precision=None): - """Return S: string | list(S) | Indented(S).""" - raise NotImplementedError( - "Missing implementation of cs_format() in CStatement." - ) - - def __str__(self): - try: - s = self.cs_format() - except Exception: - logger.error("Error in CStatement string formatting.") - raise - return format_indented_lines(s) - - def flops(self): - raise NotImplementedError() - - -# Statements + raise RuntimeError("Unexpected LExpr type %s:\n%s" % (type(node), str(node))) -class VerbatimStatement(CStatement): - """Wraps a source code string to be pasted verbatim into the source code.""" - - __slots__ = ("codestring",) - is_scoped = False - - def __init__(self, codestring): - assert isinstance(codestring, str) - self.codestring = codestring - - def cs_format(self, precision=None): - return self.codestring - - def __eq__(self, other): - return isinstance(other, type(self)) and self.codestring == other.codestring - - -class Statement(CStatement): +class Statement(LNode): """Make an expression into a statement.""" - __slots__ = ("expr",) is_scoped = False def __init__(self, expr): - self.expr = as_cexpr(expr) - - def cs_format(self, precision=None): - return self.expr.ce_format(precision) + ";" + self.expr = as_lexpr(expr) def __eq__(self, other): return isinstance(other, type(self)) and self.expr == other.expr - def flops(self): - # print(self.expr.rhs.flops()) - return self.expr.flops() - -class StatementList(CStatement): +class StatementList(LNode): """A simple sequence of statements. No new scopes are introduced.""" - __slots__ = ("statements",) - def __init__(self, statements): - self.statements = [as_cstatement(st) for st in statements] + self.statements = [as_statement(st) for st in statements] @property def is_scoped(self): return all(st.is_scoped for st in self.statements) - def cs_format(self, precision=None): - return [st.cs_format(precision) for st in self.statements] - def __eq__(self, other): return isinstance(other, type(self)) and self.statements == other.statements - def flops(self): - flops = 0 - for statement in self.statements: - flops += statement.flops() - return flops - - -# Simple statements - - -class Return(CStatement): - __slots__ = ("value",) - is_scoped = True - - def __init__(self, value=None): - if value is None: - self.value = None - else: - self.value = as_cexpr(value) - def cs_format(self, precision=None): - if self.value is None: - return "return;" - else: - return "return %s;" % (self.value.ce_format(precision),) - - def __eq__(self, other): - return isinstance(other, type(self)) and self.value == other.value - - def flops(self): - return 0 - - -class Comment(CStatement): +class Comment(Statement): """Line comment(s) used for annotating the generated code with human readable remarks.""" - __slots__ = ("comment",) is_scoped = True def __init__(self, comment): assert isinstance(comment, str) self.comment = comment - def cs_format(self, precision=None): - lines = self.comment.strip().split("\n") - return ["// " + line.strip() for line in lines] - def __eq__(self, other): return isinstance(other, type(self)) and self.comment == other.comment - def flops(self): - return 0 - - -def NoOp(): - return Comment("Do nothing") - def commented_code_list(code, comments): """Add comment to code list if the list is not empty.""" - if isinstance(code, CNode): + if isinstance(code, LNode): code = [code] assert isinstance(code, list) if code: @@ -1176,54 +687,24 @@ def commented_code_list(code, comments): return code -class Pragma(CStatement): - """Pragma comments used for compiler-specific annotations.""" - - __slots__ = ("comment",) - is_scoped = True - - def __init__(self, comment): - assert isinstance(comment, str) - self.comment = comment - - def cs_format(self, precision=None): - assert "\n" not in self.comment - return "#pragma " + self.comment - - def __eq__(self, other): - return isinstance(other, type(self)) and self.comment == other.comment - - def flops(self): - return 0 - - # Type and variable declarations -class VariableDecl(CStatement): +class VariableDecl(Statement): """Declare a variable, optionally define initial value.""" - __slots__ = ("typename", "symbol", "value") is_scoped = False - def __init__(self, typename, symbol, value=None): - # No type system yet, just using strings - assert isinstance(typename, str) - self.typename = typename + def __init__(self, symbol, value=None): - # Allow Symbol or just a string - self.symbol = as_symbol(symbol) + assert isinstance(symbol, Symbol) + assert symbol.dtype is not None + self.symbol = symbol if value is not None: - value = as_cexpr(value) + value = as_lexpr(value) self.value = value - def cs_format(self, precision=None): - code = self.typename + " " + self.symbol.name - if self.value is not None: - code += " = " + self.value.ce_format(precision) - return code + ";" - def __eq__(self, other): return ( isinstance(other, type(self)) @@ -1232,111 +713,8 @@ def __eq__(self, other): and self.value == other.value ) - def flops(self): - if self.value is not None: - return self.value.flops() - else: - return 0 - - -def leftover(size, padlen): - """Return minimum integer to add to size to make it divisible by padlen.""" - return (padlen - (size % padlen)) % padlen - - -def pad_dim(dim, padlen): - """Make dim divisible by padlen.""" - return ((dim + padlen - 1) // padlen) * padlen - - -def pad_innermost_dim(shape, padlen): - """Make the last dimension in shape divisible by padlen.""" - if not shape: - return () - shape = list(shape) - if padlen: - shape[-1] = pad_dim(shape[-1], padlen) - return tuple(shape) - - -def build_1d_initializer_list(values, formatter, padlen=0, precision=None): - """Return a list containing a single line formatted like '{ 0.0, 1.0, 2.0 }'.""" - if formatter == str: - - def formatter(x, p): - return str(x) - - tokens = ["{ "] - if np.prod(values.shape) > 0: - sep = ", " - fvalues = [formatter(v, precision) for v in values] - for v in fvalues[:-1]: - tokens.append(v) - tokens.append(sep) - tokens.append(fvalues[-1]) - if padlen: - # Add padding - zero = formatter(values.dtype.type(0), precision) - for i in range(leftover(len(values), padlen)): - tokens.append(sep) - tokens.append(zero) - tokens += " }" - return "".join(tokens) - - -def build_initializer_lists(values, sizes, level, formatter, padlen=0, precision=None): - """Return a list of lines with initializer lists for a multidimensional array. - - Example output:: - { { 0.0, 0.1 }, - { 1.0, 1.1 } } - - """ - if formatter == str: - - def formatter(x, p): - return str(x) - - values = np.asarray(values) - assert np.prod(values.shape) == np.prod(sizes) - assert len(sizes) > 0 - assert len(values.shape) > 0 - assert len(sizes) == len(values.shape) - assert np.all(values.shape == sizes) - - r = len(sizes) - assert r > 0 - if r == 1: - return [ - build_1d_initializer_list( - values, formatter, padlen=padlen, precision=precision - ) - ] - else: - # Render all sublists - parts = [] - for val in values: - sublist = build_initializer_lists( - val, sizes[1:], level + 1, formatter, padlen=padlen, precision=precision - ) - parts.append(sublist) - # Add comma after last line in each part except the last one - for part in parts[:-1]: - part[-1] += "," - # Collect all lines in flat list - lines = [] - for part in parts: - lines.extend(part) - # Enclose lines in '{ ' and ' }' and indent lines in between - lines[0] = "{ " + lines[0] - for i in range(1, len(lines)): - lines[i] = " " + lines[i] - lines[-1] += " }" - return lines - - -class ArrayDecl(CStatement): +class ArrayDecl(Statement): """A declaration or definition of an array. Note that just setting values=0 is sufficient to initialize the @@ -1347,87 +725,30 @@ class ArrayDecl(CStatement): """ - __slots__ = ("typename", "symbol", "sizes", "padlen", "values") is_scoped = False - def __init__(self, typename, symbol, sizes=None, values=None, padlen=0): - assert isinstance(typename, str) - self.typename = typename - - if isinstance(symbol, FlattenedArray): - if sizes is None: - assert symbol.dims is not None - sizes = symbol.dims - elif symbol.dims is not None: - assert symbol.dims == sizes - self.symbol = symbol.array - else: - self.symbol = as_symbol(symbol) + def __init__(self, symbol, sizes=None, values=None, const=False): + assert isinstance(symbol, Symbol) + self.symbol = symbol + assert symbol.dtype + if sizes is None: + assert values is not None + sizes = values.shape if isinstance(sizes, int): sizes = (sizes,) self.sizes = tuple(sizes) - # NB! No type checking, assuming nested lists of literal values. Not applying as_cexpr. + if values is None: + assert sizes is not None + + # NB! No type checking, assuming nested lists of literal values. Not applying as_lexpr. if isinstance(values, (list, tuple)): self.values = np.asarray(values) else: self.values = values - self.padlen = padlen - - def cs_format(self, precision=None): - if not all(self.sizes): - raise RuntimeError( - f"Detected an array {self.symbol} dimension of zero. This is not valid in C." - ) - - # Pad innermost array dimension - sizes = pad_innermost_dim(self.sizes, self.padlen) - - # Add brackets - brackets = "".join("[%d]" % n for n in sizes) - - # Join declaration - decl = self.typename + " " + self.symbol.name + brackets - - if self.values is None: - # Undefined initial values - return decl + ";" - elif _is_zero_valued(self.values): - # Zero initial values - # (NB! C style zero initialization, not sure about other target languages) - nb = len(sizes) - lbr = "{" * nb - rbr = "}" * nb - return f"{decl} = {lbr} 0 {rbr};" - else: - # Construct initializer lists for arbitrary multidimensional array values - if self.values.dtype.kind == "f": - formatter = format_float - elif self.values.dtype.kind == "i": - formatter = format_int - elif self.values.dtype == np.bool_: - - def format_bool(x, precision=None): - return "true" if x is True else "false" - - formatter = format_bool - else: - formatter = format_value - initializer_lists = build_initializer_lists( - self.values, - self.sizes, - 0, - formatter, - padlen=self.padlen, - precision=precision, - ) - if len(initializer_lists) == 1: - return decl + " = " + initializer_lists[0] + ";" - else: - initializer_lists[-1] += ";" # Close statement on final line - return (decl + " =", Indented(initializer_lists)) + self.const = const def __eq__(self, other): attributes = ("typename", "symbol", "sizes", "padlen", "values") @@ -1435,29 +756,6 @@ def __eq__(self, other): getattr(self, name) == getattr(self, name) for name in attributes ) - def flops(self): - return 0 - - -# Scoped statements - - -class Scope(CStatement): - __slots__ = ("body",) - is_scoped = True - - def __init__(self, body): - self.body = as_cstatement(body) - - def cs_format(self, precision=None): - return ("{", Indented(self.body.cs_format(precision)), "}") - - def __eq__(self, other): - return isinstance(other, type(self)) and self.body == other.body - - def flops(self): - return 0 - def is_simple_inner_loop(code): if isinstance(code, ForRange) and is_simple_inner_loop(code.body): @@ -1467,81 +765,116 @@ def is_simple_inner_loop(code): return False -class ForRange(CStatement): +class ForRange(Statement): """Slightly higher-level for loop assuming incrementing an index over a range.""" - __slots__ = ("index", "begin", "end", "body", "index_type") is_scoped = True - def __init__(self, index, begin, end, body, index_type="int"): - self.index = as_cexpr_or_string_symbol(index) - self.begin = as_cexpr(begin) - self.end = as_cexpr(end) - self.body = as_cstatement(body) - self.index_type = index_type - - def cs_format(self, precision=None): - indextype = self.index_type - index = self.index.ce_format(precision) - begin = self.begin.ce_format(precision) - end = self.end.ce_format(precision) - - init = indextype + " " + index + " = " + begin - check = index + " < " + end - update = "++" + index - - prelude = "for (" + init + "; " + check + "; " + update + ")" - body = Indented(self.body.cs_format(precision)) - - # Reduce size of code with lots of simple loops by dropping {} in obviously safe cases - if is_simple_inner_loop(self.body): - code = (prelude, body) - else: - code = (prelude, "{", body, "}") - - return code + def __init__(self, index, begin, end, body): + assert isinstance(index, Symbol) + self.index = index + self.begin = as_lexpr(begin) + self.end = as_lexpr(end) + assert isinstance(body, list) + self.body = StatementList(body) def __eq__(self, other): - attributes = ("index", "begin", "end", "body", "index_type") + attributes = ("index", "begin", "end", "body") return isinstance(other, type(self)) and all( getattr(self, name) == getattr(self, name) for name in attributes ) - def flops(self): - return (self.end.value - self.begin.value) * self.body.flops() - - -# Conversion function to statement nodes - -def as_cstatement(node): +def as_statement(node): """Perform type checking on node and wrap in a suitable statement type if necessary.""" if isinstance(node, StatementList) and len(node.statements) == 1: # Cleans up the expression tree a bit return node.statements[0] - elif isinstance(node, CStatement): + elif isinstance(node, Statement): # No-op return node - elif isinstance(node, CExprOperator): + elif isinstance(node, LExprOperator): if node.sideeffect: # Special case for using assignment expressions as statements return Statement(node) else: raise RuntimeError( - "Trying to create a statement of CExprOperator type %s:\n%s" + "Trying to create a statement of lexprOperator type %s:\n%s" % (type(node), str(node)) ) elif isinstance(node, list): # Convenience case for list of statements if len(node) == 1: # Cleans up the expression tree a bit - return as_cstatement(node[0]) + return as_statement(node[0]) else: return StatementList(node) - elif isinstance(node, str): - # Backdoor for flexibility in code generation to allow verbatim pasted statements - return VerbatimStatement(node) else: raise RuntimeError( "Unexpected CStatement type %s:\n%s" % (type(node), str(node)) ) + + +class UFL2LNodes(object): + """UFL to LNodes translator class.""" + + def __init__(self): + self.force_floats = False + self.enable_strength_reduction = False + + # Lookup table for handler to call when the "get" method (below) is + # called, depending on the first argument type. + self.call_lookup = { + ufl.constantvalue.IntValue: lambda x: LiteralInt(int(x)), + ufl.constantvalue.FloatValue: lambda x: LiteralFloat(float(x)), + ufl.constantvalue.ComplexValue: lambda x: LiteralFloat(x.value()), + ufl.constantvalue.Zero: lambda x: LiteralFloat(0.0), + ufl.algebra.Product: lambda x, a, b: a * b, + ufl.algebra.Sum: lambda x, a, b: a + b, + ufl.algebra.Division: lambda x, a, b: a / b, + ufl.algebra.Abs: self.math_function, + ufl.algebra.Power: self.math_function, + ufl.algebra.Real: self.math_function, + ufl.algebra.Imag: self.math_function, + ufl.algebra.Conj: self.math_function, + ufl.classes.GT: lambda x, a, b: GT(a, b), + ufl.classes.GE: lambda x, a, b: GE(a, b), + ufl.classes.EQ: lambda x, a, b: EQ(a, b), + ufl.classes.NE: lambda x, a, b: NE(a, b), + ufl.classes.LT: lambda x, a, b: LT(a, b), + ufl.classes.LE: lambda x, a, b: LE(a, b), + ufl.classes.AndCondition: lambda x, a, b: And(a, b), + ufl.classes.OrCondition: lambda x, a, b: Or(a, b), + ufl.classes.NotCondition: lambda x, a: Not(a), + ufl.classes.Conditional: lambda x, c, t, f: Conditional(c, t, f), + ufl.classes.MinValue: self.math_function, + ufl.classes.MaxValue: self.math_function, + ufl.mathfunctions.Sqrt: self.math_function, + ufl.mathfunctions.Ln: self.math_function, + ufl.mathfunctions.Exp: self.math_function, + ufl.mathfunctions.Cos: self.math_function, + ufl.mathfunctions.Sin: self.math_function, + ufl.mathfunctions.Tan: self.math_function, + ufl.mathfunctions.Cosh: self.math_function, + ufl.mathfunctions.Sinh: self.math_function, + ufl.mathfunctions.Tanh: self.math_function, + ufl.mathfunctions.Acos: self.math_function, + ufl.mathfunctions.Asin: self.math_function, + ufl.mathfunctions.Atan: self.math_function, + ufl.mathfunctions.Erf: self.math_function, + ufl.mathfunctions.Atan2: self.math_function, + ufl.mathfunctions.MathFunction: self.math_function, + ufl.mathfunctions.BesselJ: self.math_function, + ufl.mathfunctions.BesselY: self.math_function, + } + + def get(self, o, *args): + # Call appropriate handler, depending on the type of o + otype = type(o) + if otype in self.call_lookup: + return self.call_lookup[otype](o, *args) + else: + raise RuntimeError(f"Missing lookup for expr type {otype}.") + + def math_function(self, o, *args): + return MathFunction(o._ufl_handler_name_, args) From 3a81fe838b34af3c8e49a960f93a0c87404d738b Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Fri, 25 Aug 2023 14:40:31 +0100 Subject: [PATCH 03/13] Adjust for dtype --- ffcx/codegeneration/C/c_implementation.py | 332 ++++++++++++++++++++ ffcx/codegeneration/C/expressions.py | 6 +- ffcx/codegeneration/C/integrals.py | 5 +- ffcx/codegeneration/access.py | 19 +- ffcx/codegeneration/backend.py | 17 +- ffcx/codegeneration/definitions.py | 17 +- ffcx/codegeneration/expression_generator.py | 17 +- ffcx/codegeneration/geometry.py | 16 +- ffcx/codegeneration/integral_generator.py | 117 +++---- ffcx/codegeneration/symbols.py | 50 ++- ffcx/codegeneration/utils.py | 34 ++ 11 files changed, 464 insertions(+), 166 deletions(-) create mode 100644 ffcx/codegeneration/C/c_implementation.py create mode 100644 ffcx/codegeneration/utils.py diff --git a/ffcx/codegeneration/C/c_implementation.py b/ffcx/codegeneration/C/c_implementation.py new file mode 100644 index 000000000..f735673c8 --- /dev/null +++ b/ffcx/codegeneration/C/c_implementation.py @@ -0,0 +1,332 @@ +# Copyright (C) 2023 Chris Richardson +# +# This file is part of FFCx. (https://www.fenicsproject.org) +# +# SPDX-License-Identifier: LGPL-3.0-or-later + +import warnings +import ffcx.codegeneration.lnodes as L +from ffcx.codegeneration.utils import scalar_to_value_type + +math_table = { + "double": { + "sqrt": "sqrt", + "abs": "fabs", + "cos": "cos", + "sin": "sin", + "tan": "tan", + "acos": "acos", + "asin": "asin", + "atan": "atan", + "cosh": "cosh", + "sinh": "sinh", + "tanh": "tanh", + "acosh": "acosh", + "asinh": "asinh", + "atanh": "atanh", + "power": "pow", + "exp": "exp", + "ln": "log", + "erf": "erf", + "atan_2": "atan2", + "min_value": "fmin", + "max_value": "fmax", + "bessel_y": "yn", + "bessel_j": "jn", + }, + "float": { + "sqrt": "sqrtf", + "abs": "fabsf", + "cos": "cosf", + "sin": "sinf", + "tan": "tanf", + "acos": "acosf", + "asin": "asinf", + "atan": "atanf", + "cosh": "coshf", + "sinh": "sinhf", + "tanh": "tanhf", + "acosh": "acoshf", + "asinh": "asinhf", + "atanh": "atanhf", + "power": "powf", + "exp": "expf", + "ln": "logf", + "erf": "erff", + "atan_2": "atan2f", + "min_value": "fminf", + "max_value": "fmaxf", + "bessel_y": "yn", + "bessel_j": "jn", + }, + "long double": { + "sqrt": "sqrtl", + "abs": "fabsl", + "cos": "cosl", + "sin": "sinl", + "tan": "tanl", + "acos": "acosl", + "asin": "asinl", + "atan": "atanl", + "cosh": "coshl", + "sinh": "sinhl", + "tanh": "tanhl", + "acosh": "acoshl", + "asinh": "asinhl", + "atanh": "atanhl", + "power": "powl", + "exp": "expl", + "ln": "logl", + "erf": "erfl", + "atan_2": "atan2l", + "min_value": "fminl", + "max_value": "fmaxl", + }, + "double _Complex": { + "sqrt": "csqrt", + "abs": "cabs", + "cos": "ccos", + "sin": "csin", + "tan": "ctan", + "acos": "cacos", + "asin": "casin", + "atan": "catan", + "cosh": "ccosh", + "sinh": "csinh", + "tanh": "ctanh", + "acosh": "cacosh", + "asinh": "casinh", + "atanh": "catanh", + "power": "cpow", + "exp": "cexp", + "ln": "clog", + "real": "creal", + "imag": "cimag", + "conj": "conj", + "max_value": "fmax", + "min_value": "fmin", + "bessel_y": "yn", + "bessel_j": "jn", + }, + "float _Complex": { + "sqrt": "csqrtf", + "abs": "cabsf", + "cos": "ccosf", + "sin": "csinf", + "tan": "ctanf", + "acos": "cacosf", + "asin": "casinf", + "atan": "catanf", + "cosh": "ccoshf", + "sinh": "csinhf", + "tanh": "ctanhf", + "acosh": "cacoshf", + "asinh": "casinhf", + "atanh": "catanhf", + "power": "cpowf", + "exp": "cexpf", + "ln": "clogf", + "real": "crealf", + "imag": "cimagf", + "conj": "conjf", + "max_value": "fmaxf", + "min_value": "fminf", + "bessel_y": "yn", + "bessel_j": "jn", + }, +} + + +def build_initializer_lists(values): + arr = "{" + if len(values.shape) == 1: + arr += ", ".join(str(v) for v in values) + elif len(values.shape) > 1: + arr += ",\n ".join(build_initializer_lists(v) for v in values) + arr += "}" + return arr + + +class CFormatter(object): + def __init__(self, scalar) -> None: + self.scalar_type = scalar + self.real_type = scalar_to_value_type(scalar) + + def format_statement_list(self, slist) -> str: + return "".join(self.c_format(s) for s in slist.statements) + + def format_comment(self, c) -> str: + return "// " + c.comment + "\n" + + def format_array_decl(self, arr) -> str: + dtype = arr.symbol.dtype + assert dtype is not None + + if dtype == L.DataType.SCALAR: + typename = self.scalar_type + elif dtype == L.DataType.REAL: + typename = self.real_type + else: + raise ValueError(f"Invalid dtype: {dtype}") + + symbol = self.c_format(arr.symbol) + dims = "".join([f"[{i}]" for i in arr.sizes]) + if arr.values is None: + assert arr.const is False + return f"{typename} {symbol}{dims};\n" + + vals = build_initializer_lists(arr.values) + cstr = "static const " if arr.const else "" + return f"{cstr}{typename} {symbol}{dims} = {vals};\n" + + def format_array_access(self, arr) -> str: + name = self.c_format(arr.array) + indices = f"[{']['.join(self.c_format(i) for i in arr.indices)}]" + return f"{name}{indices}" + + def format_variable_decl(self, v) -> str: + val = self.c_format(v.value) + symbol = self.c_format(v.symbol) + assert v.symbol.dtype + if v.symbol.dtype == L.DataType.SCALAR: + typename = self.scalar_type + elif v.symbol.dtype == L.DataType.REAL: + typename = self.real_type + return f"{typename} {symbol} = {val};\n" + + def format_nary_op(self, oper) -> str: + # Format children + args = [self.c_format(arg) for arg in oper.args] + + # Apply parentheses + for i in range(len(args)): + if oper.args[i].precedence >= oper.precedence: + args[i] = "(" + args[i] + ")" + + # Return combined string + return f" {oper.op} ".join(args) + + def format_binary_op(self, oper) -> str: + # Format children + lhs = self.c_format(oper.lhs) + rhs = self.c_format(oper.rhs) + + # Apply parentheses + if oper.lhs.precedence >= oper.precedence: + lhs = f"({lhs})" + if oper.rhs.precedence >= oper.precedence: + rhs = f"({rhs})" + + # Return combined string + return f"{lhs} {oper.op} {rhs}" + + def format_neg(self, val) -> str: + arg = self.c_format(val.arg) + return f"-{arg}" + + def format_not(self, val) -> str: + arg = self.c_format(val.arg) + return f"{val.op}({arg})" + + def format_literal_float(self, val) -> str: + return f"{val.value}" + + def format_literal_int(self, val) -> str: + return f"{val.value}" + + def format_for_range(self, r) -> str: + begin = self.c_format(r.begin) + end = self.c_format(r.end) + index = self.c_format(r.index) + output = f"for (int {index} = {begin}; {index} < {end}; ++{index})\n" + output += "{\n" + body = self.c_format(r.body) + for line in body.split("\n"): + if len(line) > 0: + output += f" {line}\n" + output += "}\n" + return output + + def format_statement(self, s) -> str: + return self.c_format(s.expr) + + def format_assign(self, expr) -> str: + rhs = self.c_format(expr.rhs) + lhs = self.c_format(expr.lhs) + return f"{lhs} {expr.op} {rhs};\n" + + def format_conditional(self, s) -> str: + # Format children + c = self.c_format(s.condition) + t = self.c_format(s.true) + f = self.c_format(s.false) + + # Apply parentheses + if s.condition.precedence >= s.precedence: + c = "(" + c + ")" + if s.true.precedence >= s.precedence: + t = "(" + t + ")" + if s.false.precedence >= s.precedence: + f = "(" + f + ")" + + # Return combined string + return c + " ? " + t + " : " + f + + def format_symbol(self, s) -> str: + return f"{s.name}" + + def format_math_function(self, c) -> str: + # Get a table of functions for this type, if available + arg_type = self.scalar_type + if hasattr(c.args[0], "dtype"): + if c.args[0].dtype == L.DataType.REAL: + arg_type = self.real_type + else: + warnings.warn(f"Syntax item without dtype {c.args[0]}") + + dtype_math_table = math_table.get(arg_type, {}) + + # Get a function from the table, if available, else just use bare name + func = dtype_math_table.get(c.function, c.function) + args = ", ".join(self.c_format(arg) for arg in c.args) + return f"{func}({args})" + + c_impl = { + "StatementList": format_statement_list, + "Comment": format_comment, + "ArrayDecl": format_array_decl, + "ArrayAccess": format_array_access, + "VariableDecl": format_variable_decl, + "ForRange": format_for_range, + "Statement": format_statement, + "Assign": format_assign, + "AssignAdd": format_assign, + "Product": format_nary_op, + "Neg": format_neg, + "Sum": format_nary_op, + "Add": format_binary_op, + "Sub": format_binary_op, + "Mul": format_binary_op, + "Div": format_binary_op, + "Not": format_not, + "LiteralFloat": format_literal_float, + "LiteralInt": format_literal_int, + "Symbol": format_symbol, + "Conditional": format_conditional, + "MathFunction": format_math_function, + "And": format_binary_op, + "Or": format_binary_op, + "NE": format_binary_op, + "EQ": format_binary_op, + "GE": format_binary_op, + "LE": format_binary_op, + "GT": format_binary_op, + "LT": format_binary_op, + } + + def c_format(self, s) -> str: + name = s.__class__.__name__ + try: + return self.c_impl[name](self, s) + except KeyError: + raise RuntimeError("Unknown statement: ", name) diff --git a/ffcx/codegeneration/C/expressions.py b/ffcx/codegeneration/C/expressions.py index a994b8c5c..530c0ab2a 100644 --- a/ffcx/codegeneration/C/expressions.py +++ b/ffcx/codegeneration/C/expressions.py @@ -9,7 +9,7 @@ from ffcx.codegeneration.C import expressions_template from ffcx.codegeneration.expression_generator import ExpressionGenerator from ffcx.codegeneration.backend import FFCXBackend -from ffcx.codegeneration.C.format_lines import format_indented_lines +from ffcx.codegeneration.C.c_implementation import CFormatter from ffcx.naming import cdtype_to_numpy, scalar_to_value_type logger = logging.getLogger("ffcx") @@ -36,8 +36,8 @@ def generator(ir, options): parts = eg.generate() - body = format_indented_lines(parts.cs_format(), 1) - d["tabulate_expression"] = body + CF = CFormatter(options["scalar_type"]) + d["tabulate_expression"] = CF.c_format(parts) if len(ir.original_coefficient_positions) > 0: d["original_coefficient_positions"] = f"original_coefficient_positions_{ir.name}" diff --git a/ffcx/codegeneration/C/integrals.py b/ffcx/codegeneration/C/integrals.py index b1b73c525..c6916591a 100644 --- a/ffcx/codegeneration/C/integrals.py +++ b/ffcx/codegeneration/C/integrals.py @@ -9,7 +9,7 @@ from ffcx.codegeneration.integral_generator import IntegralGenerator from ffcx.codegeneration.C import integrals_template as ufcx_integrals from ffcx.codegeneration.backend import FFCXBackend -from ffcx.codegeneration.C.format_lines import format_indented_lines +from ffcx.codegeneration.C.c_implementation import CFormatter from ffcx.naming import cdtype_to_numpy, scalar_to_value_type logger = logging.getLogger("ffcx") @@ -36,7 +36,8 @@ def generator(ir, options): parts = ig.generate() # Format code as string - body = format_indented_lines(parts.cs_format(ir.precision), 1) + CF = CFormatter(options["scalar_type"]) + body = CF.c_format(parts) # Generate generic FFCx code snippets and add specific parts code = {} diff --git a/ffcx/codegeneration/access.py b/ffcx/codegeneration/access.py index 9d1dcc985..b88de608b 100644 --- a/ffcx/codegeneration/access.py +++ b/ffcx/codegeneration/access.py @@ -11,6 +11,7 @@ import ufl import basix.ufl from ffcx.element_interface import convert_element +import ffcx.codegeneration.lnodes as L logger = logging.getLogger("ffcx") @@ -18,12 +19,11 @@ class FFCXBackendAccess(object): """FFCx specific cpp formatter class.""" - def __init__(self, ir, language, symbols, options): + def __init__(self, ir, symbols, options): # Store ir and options self.entitytype = ir.entitytype self.integral_type = ir.integral_type - self.language = language self.symbols = symbols self.options = options @@ -178,10 +178,9 @@ def jacobian(self, e, mt, tabledata, num_points): return self.symbols.J_component(mt) def reference_cell_volume(self, e, mt, tabledata, access): - L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"): - return L.Symbol(f"{cellname}_reference_cell_volume") + return L.Symbol(f"{cellname}_reference_cell_volume", dtype=L.DataType.REAL) else: raise RuntimeError(f"Unhandled cell types {cellname}.") @@ -189,7 +188,7 @@ def reference_facet_volume(self, e, mt, tabledata, access): L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"): - return L.Symbol(f"{cellname}_reference_facet_volume") + return L.Symbol(f"{cellname}_reference_facet_volume", dtype=L.DataType.REAL) else: raise RuntimeError(f"Unhandled cell types {cellname}.") @@ -197,7 +196,7 @@ def reference_normal(self, e, mt, tabledata, access): L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"): - table = L.Symbol(f"{cellname}_reference_facet_normals") + table = L.Symbol(f"{cellname}_reference_facet_normals", dtype=L.DataType.REAL) facet = self.symbols.entity("facet", mt.restriction) return table[facet][mt.component[0]] else: @@ -207,7 +206,7 @@ def cell_facet_jacobian(self, e, mt, tabledata, num_points): L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"): - table = L.Symbol(f"{cellname}_reference_facet_jacobian") + table = L.Symbol(f"{cellname}_reference_facet_jacobian", dtype=L.DataType.REAL) facet = self.symbols.entity("facet", mt.restriction) return table[facet][mt.component[0]][mt.component[1]] elif cellname == "interval": @@ -219,7 +218,7 @@ def reference_cell_edge_vectors(self, e, mt, tabledata, num_points): L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"): - table = L.Symbol(f"{cellname}_reference_edge_vectors") + table = L.Symbol(f"{cellname}_reference_edge_vectors", dtype=L.DataType.REAL) return table[mt.component[0]][mt.component[1]] elif cellname == "interval": raise RuntimeError("The reference cell edge vectors doesn't make sense for interval cell.") @@ -230,7 +229,7 @@ def reference_facet_edge_vectors(self, e, mt, tabledata, num_points): L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("tetrahedron", "hexahedron"): - table = L.Symbol(f"{cellname}_reference_edge_vectors") + table = L.Symbol(f"{cellname}_reference_edge_vectors", dtype=L.DataType.REAL) facet = self.symbols.entity("facet", mt.restriction) return table[facet][mt.component[0]][mt.component[1]] elif cellname in ("interval", "triangle", "quadrilateral"): @@ -246,7 +245,7 @@ def facet_orientation(self, e, mt, tabledata, num_points): if cellname not in ("interval", "triangle", "tetrahedron"): raise RuntimeError(f"Unhandled cell types {cellname}.") - table = L.Symbol(f"{cellname}_facet_orientations") + table = L.Symbol(f"{cellname}_facet_orientations", dtype=L.DataType.INT) facet = self.symbols.entity("facet", mt.restriction) return table[facet] diff --git a/ffcx/codegeneration/backend.py b/ffcx/codegeneration/backend.py index 0b9c5d8d2..b874196e4 100644 --- a/ffcx/codegeneration/backend.py +++ b/ffcx/codegeneration/backend.py @@ -5,11 +5,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Collection of FFCx specific pieces for the code generation phase.""" -import types - -import ffcx.codegeneration.C.cnodes from ffcx.codegeneration.access import FFCXBackendAccess -from ffcx.codegeneration.C.ufl_to_cnodes import UFL2CNodesTranslatorCpp from ffcx.codegeneration.definitions import FFCXBackendDefinitions from ffcx.codegeneration.symbols import FFCXBackendSymbols @@ -19,19 +15,12 @@ class FFCXBackend(object): def __init__(self, ir, options): - # This is the seam where cnodes/C is chosen for the FFCx backend - self.language: types.ModuleType = ffcx.codegeneration.C.cnodes - scalar_type = options["scalar_type"] - self.ufl_to_language = UFL2CNodesTranslatorCpp(self.language, scalar_type) - coefficient_numbering = ir.coefficient_numbering coefficient_offsets = ir.coefficient_offsets original_constant_offsets = ir.original_constant_offsets - self.symbols = FFCXBackendSymbols(self.language, coefficient_numbering, + self.symbols = FFCXBackendSymbols(coefficient_numbering, coefficient_offsets, original_constant_offsets) - self.definitions = FFCXBackendDefinitions(ir, self.language, - self.symbols, options) - self.access = FFCXBackendAccess(ir, self.language, self.symbols, - options) + self.definitions = FFCXBackendDefinitions(ir, self.symbols, options) + self.access = FFCXBackendAccess(ir, self.symbols, options) diff --git a/ffcx/codegeneration/definitions.py b/ffcx/codegeneration/definitions.py index 1b26de95f..07390fb1c 100644 --- a/ffcx/codegeneration/definitions.py +++ b/ffcx/codegeneration/definitions.py @@ -9,7 +9,7 @@ import ufl from ffcx.element_interface import convert_element -from ffcx.naming import scalar_to_value_type +import ffcx.codegeneration.lnodes as L logger = logging.getLogger("ffcx") @@ -17,11 +17,10 @@ class FFCXBackendDefinitions(object): """FFCx specific code definitions.""" - def __init__(self, ir, language, symbols, options): + def __init__(self, ir, symbols, options): # Store ir and options self.integral_type = ir.integral_type self.entitytype = ir.entitytype - self.language = language self.symbols = symbols self.options = options @@ -64,8 +63,6 @@ def get(self, t, mt, tabledata, quadrature_rule, access): def coefficient(self, t, mt, tabledata, quadrature_rule, access): """Return definition code for coefficients.""" - L = self.language - ttype = tabledata.ttype num_dofs = tabledata.values.shape[3] bs = tabledata.block_size @@ -106,7 +103,7 @@ def coefficient(self, t, mt, tabledata, quadrature_rule, access): dof_access = self.symbols.coefficient_dof_access(mt.terminal, ic * bs + begin) body = [L.AssignAdd(access, dof_access * FE[ic])] - code += [L.VariableDecl(self.options["scalar_type"], access, 0.0)] + code += [L.VariableDecl(access, 0.0)] code += [L.ForRange(ic, 0, num_dofs, body)] return pre_code, code @@ -119,8 +116,6 @@ def constant(self, t, mt, tabledata, quadrature_rule, access): def _define_coordinate_dofs_lincomb(self, e, mt, tabledata, quadrature_rule, access): """Define x or J as a linear combination of coordinate dofs with given table data.""" - L = self.language - # Get properties of domain domain = ufl.domain.extract_unique_domain(mt.terminal) coordinate_element = domain.ufl_coordinate_element() @@ -140,7 +135,7 @@ def _define_coordinate_dofs_lincomb(self, e, mt, tabledata, quadrature_rule, acc # Get access to element table FE = self.symbols.element_table(tabledata, self.entitytype, mt.restriction) ic = self.symbols.coefficient_dof_sum_index() - dof_access = self.symbols.S("coordinate_dofs") + dof_access = L.Symbol("coordinate_dofs", dtype=L.DataType.REAL) # coordinate dofs is always 3d dim = 3 @@ -148,11 +143,9 @@ def _define_coordinate_dofs_lincomb(self, e, mt, tabledata, quadrature_rule, acc if mt.restriction == "-": offset = num_scalar_dofs * dim - value_type = scalar_to_value_type(self.options["scalar_type"]) - code = [] body = [L.AssignAdd(access, dof_access[ic * dim + begin + offset] * FE[ic])] - code += [L.VariableDecl(f"{value_type}", access, 0.0)] + code += [L.VariableDecl(access, 0.0)] code += [L.ForRange(ic, 0, num_scalar_dofs, body)] return [], code diff --git a/ffcx/codegeneration/expression_generator.py b/ffcx/codegeneration/expression_generator.py index 553e8b315..28ee90e29 100644 --- a/ffcx/codegeneration/expression_generator.py +++ b/ffcx/codegeneration/expression_generator.py @@ -12,7 +12,8 @@ import ufl from ffcx.codegeneration import geometry from ffcx.codegeneration.backend import FFCXBackend -from ffcx.codegeneration.C.cnodes import CNode +import ffcx.codegeneration.lnodes as L +from ffcx.codegeneration.lnodes import LNode from ffcx.ir.representation import ExpressionIR from ffcx.naming import scalar_to_value_type @@ -27,7 +28,7 @@ def __init__(self, ir: ExpressionIR, backend: FFCXBackend): self.ir = ir self.backend = backend - self.scope: Dict[Any, CNode] = {} + self.scope: Dict[Any, LNode] = {} self._ufl_names: Set[Any] = set() self.symbol_counters: DefaultDict[Any, int] = collections.defaultdict(int) self.shared_symbols: Dict[Any, Any] = {} @@ -58,10 +59,8 @@ def generate(self): return L.StatementList(parts) - def generate_geometry_tables(self, float_type: str): + def generate_geometry_tables(self): """Generate static tables of geometry data.""" - L = self.backend.language - # Currently we only support circumradius ufl_geometry = { ufl.geometry.ReferenceCellVolume: "reference_cell_volume", @@ -79,24 +78,20 @@ def generate_geometry_tables(self, float_type: str): parts = [] for i, cell_list in cells.items(): for c in cell_list: - parts.append(geometry.write_table(L, ufl_geometry[i], c, float_type)) + parts.append(geometry.write_table(L, ufl_geometry[i], c)) return parts def generate_element_tables(self, float_type: str): """Generate tables of FE basis evaluated at specified points.""" - L = self.backend.language parts = [] tables = self.ir.unique_tables - - padlen = self.ir.options["padlen"] table_names = sorted(tables) for name in table_names: table = tables[name] - decl = L.ArrayDecl( - f"static const {float_type}", name, table.shape, table, padlen=padlen) + decl = L.ArrayDecl(name, table) parts += [decl] # Add leading comment if there are any tables diff --git a/ffcx/codegeneration/geometry.py b/ffcx/codegeneration/geometry.py index 2df2bcd93..271a438fa 100644 --- a/ffcx/codegeneration/geometry.py +++ b/ffcx/codegeneration/geometry.py @@ -9,23 +9,23 @@ import basix -def write_table(L, tablename, cellname, type: str): +def write_table(L, tablename, cellname): if tablename == "facet_edge_vertices": return facet_edge_vertices(L, tablename, cellname) if tablename == "reference_facet_jacobian": - return reference_facet_jacobian(L, tablename, cellname, type) + return reference_facet_jacobian(L, tablename, cellname) if tablename == "reference_cell_volume": - return reference_cell_volume(L, tablename, cellname, type) + return reference_cell_volume(L, tablename, cellname) if tablename == "reference_facet_volume": - return reference_facet_volume(L, tablename, cellname, type) + return reference_facet_volume(L, tablename, cellname) if tablename == "reference_edge_vectors": - return reference_edge_vectors(L, tablename, cellname, type) + return reference_edge_vectors(L, tablename, cellname) if tablename == "facet_reference_edge_vectors": - return facet_reference_edge_vectors(L, tablename, cellname, type) + return facet_reference_edge_vectors(L, tablename, cellname) if tablename == "reference_facet_normals": - return reference_facet_normals(L, tablename, cellname, type) + return reference_facet_normals(L, tablename, cellname) if tablename == "facet_orientation": - return facet_orientation(L, tablename, cellname, type) + return facet_orientation(L, tablename, cellname) raise ValueError(f"Unknown geometry table name: {tablename}") diff --git a/ffcx/codegeneration/integral_generator.py b/ffcx/codegeneration/integral_generator.py index 1ebcef70f..c58111737 100644 --- a/ffcx/codegeneration/integral_generator.py +++ b/ffcx/codegeneration/integral_generator.py @@ -1,4 +1,4 @@ -# Copyright (C) 2015-2021 Martin Sandve Alnæs, Michal Habera, Igor Baratta +# Copyright (C) 2015-2023 Martin Sandve Alnæs, Michal Habera, Igor Baratta, Chris Richardson # # This file is part of FFCx. (https://www.fenicsproject.org) # @@ -10,11 +10,11 @@ import ufl from ffcx.codegeneration import geometry -from ffcx.codegeneration.C.cnodes import BinOp, CNode from ffcx.ir.elementtables import piecewise_ttypes from ffcx.ir.integral import BlockDataT +import ffcx.codegeneration.lnodes as L +from ffcx.codegeneration.lnodes import LNode, BinOp from ffcx.ir.representationutils import QuadratureRule -from ffcx.naming import scalar_to_value_type logger = logging.getLogger("ffcx") @@ -25,11 +25,11 @@ def __init__(self, ir, backend): self.ir = ir # Backend specific plugin with attributes - # - language: for translating ufl operators to target language # - symbols: for translating ufl operators to target language # - definitions: for defining backend specific variables # - access: for accessing backend specific variables self.backend = backend + self.ufl_to_language = L.UFL2LNodes() # Set of operator names code has been generated for, used in the # end for selecting necessary includes @@ -57,7 +57,7 @@ def set_var(self, quadrature_rule, v, vaccess): Scope is determined by quadrature_rule which identifies the quadrature loop scope or None if outside quadrature loops. - v is the ufl expression and vaccess is the CNodes + v is the ufl expression and vaccess is the LNodes expression to access the value in the code. """ @@ -72,10 +72,10 @@ def get_var(self, quadrature_rule, v): If v is not found in quadrature loop scope, the piecewise scope (None) is checked. - Returns the CNodes expression to access the value in the code. + Returns the LNodes expression to access the value in the code. """ if v._ufl_is_literal_: - return self.backend.ufl_to_language.get(v) + return self.ufl_to_language.get(v) f = self.scopes[quadrature_rule].get(v) if f is None: f = self.scopes[None].get(v) @@ -83,13 +83,12 @@ def get_var(self, quadrature_rule, v): def new_temp_symbol(self, basename): """Create a new code symbol named basename + running counter.""" - L = self.backend.language name = "%s%d" % (basename, self.symbol_counters[basename]) self.symbol_counters[basename] += 1 - return L.Symbol(name) + return L.Symbol(name, dtype=L.DataType.SCALAR) def get_temp_symbol(self, tempname, key): - key = (tempname, ) + key + key = (tempname,) + key s = self.shared_symbols.get(key) defined = s is not None if not defined: @@ -104,32 +103,21 @@ def generate(self): context that matches a suitable version of the UFC tabulate_tensor signatures. """ - L = self.backend.language - # Assert that scopes are empty: expecting this to be called only # once assert not any(d for d in self.scopes.values()) parts = [] - scalar_type = self.backend.access.options["scalar_type"] - value_type = scalar_to_value_type(scalar_type) - alignment = self.ir.options['assume_aligned'] - if alignment != -1: - scalar_type = self.backend.access.options["scalar_type"] - parts += [L.VerbatimStatement(f"A = ({scalar_type}*)__builtin_assume_aligned(A, {alignment});"), - L.VerbatimStatement(f"w = (const {scalar_type}*)__builtin_assume_aligned(w, {alignment});"), - L.VerbatimStatement(f"c = (const {scalar_type}*)__builtin_assume_aligned(c, {alignment});"), - L.VerbatimStatement(f"coordinate_dofs = (const {value_type}*)__builtin_assume_aligned(coordinate_dofs, {alignment});")] # noqa # Generate the tables of quadrature points and weights - parts += self.generate_quadrature_tables(value_type) + parts += self.generate_quadrature_tables() # Generate the tables of basis function values and # pre-integrated blocks - parts += self.generate_element_tables(value_type) + parts += self.generate_element_tables() # Generate the tables of geometry data that are needed - parts += self.generate_geometry_tables(value_type) + parts += self.generate_geometry_tables() # Loop generation code will produce parts to go before # quadloops, to define the quadloops, and to go after the @@ -160,11 +148,9 @@ def generate(self): return L.StatementList(parts) - def generate_quadrature_tables(self, value_type: str) -> List[str]: + def generate_quadrature_tables(self): """Generate static tables of quadrature points and weights.""" - L = self.backend.language - - parts: List[str] = [] + parts = [] # No quadrature tables for custom (given argument) or point # (evaluation in single vertex) @@ -172,25 +158,18 @@ def generate_quadrature_tables(self, value_type: str) -> List[str]: if self.ir.integral_type in skip: return parts - padlen = self.ir.options["padlen"] - # Loop over quadrature rules for quadrature_rule, integrand in self.ir.integrand.items(): - num_points = quadrature_rule.weights.shape[0] - # Generate quadrature weights array wsym = self.backend.symbols.weights_table(quadrature_rule) - parts += [L.ArrayDecl(f"static const {value_type}", wsym, num_points, - quadrature_rule.weights, padlen=padlen)] + parts += [L.ArrayDecl(wsym, values=quadrature_rule.weights, const=True)] # Add leading comment if there are any tables parts = L.commented_code_list(parts, "Quadrature rules") return parts - def generate_geometry_tables(self, float_type: str): + def generate_geometry_tables(self): """Generate static tables of geometry data.""" - L = self.backend.language - ufl_geometry = { ufl.geometry.FacetEdgeVectors: "facet_edge_vertices", ufl.geometry.CellFacetJacobian: "reference_facet_jacobian", @@ -214,17 +193,15 @@ def generate_geometry_tables(self, float_type: str): parts = [] for i, cell_list in cells.items(): for c in cell_list: - parts.append(geometry.write_table(L, ufl_geometry[i], c, float_type)) + parts.append(geometry.write_table(L, ufl_geometry[i], c)) return parts - def generate_element_tables(self, float_type: str): + def generate_element_tables(self): """Generate static tables with precomputed element basisfunction values in quadrature points.""" - L = self.backend.language parts = [] tables = self.ir.unique_tables table_types = self.ir.unique_table_types - padlen = self.ir.options["padlen"] if self.ir.integral_type in ufl.custom_integral_types: # Define only piecewise tables table_names = [name for name in sorted(tables) if table_types[name] in piecewise_ttypes] @@ -234,7 +211,7 @@ def generate_element_tables(self, float_type: str): for name in table_names: table = tables[name] - parts += self.declare_table(name, table, padlen, float_type) + parts += self.declare_table(name, table) # Add leading comment if there are any tables parts = L.commented_code_list(parts, [ @@ -242,19 +219,18 @@ def generate_element_tables(self, float_type: str): "FE* dimensions: [permutation][entities][points][dofs]"]) return parts - def declare_table(self, name, table, padlen, value_type: str): + def declare_table(self, name, table): """Declare a table. If the dof dimensions of the table have dof rotations, apply these rotations. """ - L = self.backend.language - return [L.ArrayDecl(f"static const {value_type}", name, table.shape, table, padlen=padlen)] + table_symbol = L.Symbol(name, dtype=L.DataType.REAL) + return [L.ArrayDecl(table_symbol, values=table, const=True)] def generate_quadrature_loop(self, quadrature_rule: QuadratureRule): """Generate quadrature loop with for this quadrature_rule.""" - L = self.backend.language # Generate varying partition pre_definitions, body = self.generate_varying_partition(quadrature_rule) @@ -278,12 +254,10 @@ def generate_quadrature_loop(self, quadrature_rule: QuadratureRule): return pre_definitions, preparts, quadparts def generate_piecewise_partition(self, quadrature_rule): - L = self.backend.language - # Get annotated graph of factorisation F = self.ir.integrand[quadrature_rule]["factorization"] - arraysymbol = L.Symbol(f"sp_{quadrature_rule.id()}") + arraysymbol = L.Symbol(f"sp_{quadrature_rule.id()}", dtype=L.DataType.SCALAR) pre_definitions, parts = self.generate_partition(arraysymbol, F, "piecewise", None) assert len(pre_definitions) == 0, "Quadrature independent code should have not pre-definitions" parts = L.commented_code_list( @@ -292,19 +266,17 @@ def generate_piecewise_partition(self, quadrature_rule): return parts def generate_varying_partition(self, quadrature_rule): - L = self.backend.language # Get annotated graph of factorisation F = self.ir.integrand[quadrature_rule]["factorization"] - arraysymbol = L.Symbol(f"sv_{quadrature_rule.id()}") + arraysymbol = L.Symbol(f"sv_{quadrature_rule.id()}", dtype=L.DataType.SCALAR) pre_definitions, parts = self.generate_partition(arraysymbol, F, "varying", quadrature_rule) parts = L.commented_code_list(parts, f"Varying computations for quadrature rule {quadrature_rule.id()}") return pre_definitions, parts def generate_partition(self, symbol, F, mode, quadrature_rule): - L = self.backend.language definitions = dict() pre_definitions = dict() @@ -322,7 +294,7 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): # cache if not self.get_var(quadrature_rule, v): if v._ufl_is_literal_: - vaccess = self.backend.ufl_to_language.get(v) + vaccess = self.ufl_to_language.get(v) elif mt is not None: # All finite element based terminals have table # data, as well as some, but not all, of the @@ -352,7 +324,7 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): # Mapping UFL operator to target language self._ufl_names.add(v._ufl_handler_name_) - vexpr = self.backend.ufl_to_language.get(v, *vops) + vexpr = self.ufl_to_language.get(v, *vops) # Create a new intermediate for each subexpression # except boolean conditions and its childs @@ -379,9 +351,8 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): vaccess = symbol[j] intermediates.append(L.Assign(vaccess, vexpr)) else: - scalar_type = self.backend.access.options["scalar_type"] vaccess = L.Symbol("%s_%d" % (symbol.name, j)) - intermediates.append(L.VariableDecl(f"const {scalar_type}", vaccess, vexpr)) + intermediates.append(L.VariableDecl(vaccess, vexpr)) # Store access node for future reference self.set_var(quadrature_rule, v, vaccess) @@ -393,9 +364,7 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): if intermediates: if use_symbol_array: - padlen = self.ir.options["padlen"] - parts += [L.ArrayDecl(self.backend.access.options["scalar_type"], - symbol, len(intermediates), padlen=padlen)] + parts += [L.ArrayDecl(symbol, sizes=len(intermediates))] parts += intermediates return pre_definitions, parts @@ -467,11 +436,9 @@ def generate_block_parts(self, quadrature_rule: QuadratureRule, blockmap: Tuple, Should be called with quadrature_rule=None for quadloop-independent blocks. """ - L = self.backend.language - # The parts to return - preparts: List[CNode] = [] - quadparts: List[CNode] = [] + preparts: List[LNode] = [] + quadparts: List[LNode] = [] # RHS expressions grouped by LHS "dofmap" rhs_expressions = collections.defaultdict(list) @@ -523,8 +490,7 @@ def generate_block_parts(self, quadrature_rule: QuadratureRule, blockmap: Tuple, key = (quadrature_rule, factor_index, blockdata.all_factors_piecewise) fw, defined = self.get_temp_symbol("fw", key) if not defined: - scalar_type = self.backend.access.options["scalar_type"] - quadparts.append(L.VariableDecl(f"const {scalar_type}", fw, fw_rhs)) + quadparts.append(L.VariableDecl(fw, fw_rhs)) assert not blockdata.transposed, "Not handled yet" A_shape = self.ir.tensor_shape @@ -551,7 +517,7 @@ def generate_block_parts(self, quadrature_rule: QuadratureRule, blockmap: Tuple, # List of statements to keep in the inner loop keep = collections.defaultdict(list) # List of temporary array declarations - pre_loop: List[CNode] = [] + pre_loop: List[LNode] = [] # List of loop invariant expressions to hoist hoist: List[BinOp] = [] @@ -577,34 +543,29 @@ def generate_block_parts(self, quadrature_rule: QuadratureRule, blockmap: Tuple, # floating point operations (factorize expressions by # grouping) for statement in hoist_rhs: - sum = [] - for rhs in hoist_rhs[statement]: - sum.append(L.float_product(rhs)) - sum = L.Sum(sum) + sum = L.Sum([L.float_product(rhs) for rhs in hoist_rhs[statement]]) lhs = None for h in hoist: - if (h.rhs == sum): + if h.rhs == sum: lhs = h.lhs break if lhs: keep[indices].append(L.float_product([statement, lhs])) else: t = self.new_temp_symbol("t") - scalar_type = self.backend.access.options["scalar_type"] - pre_loop.append(L.ArrayDecl(scalar_type, t, blockdims[0])) + pre_loop.append(L.ArrayDecl(t, sizes=blockdims[0])) keep[indices].append(L.float_product([statement, t[B_indices[0]]])) hoist.append(L.Assign(t[B_indices[i - 1]], sum)) else: keep[indices] = rhs_expressions[indices] - hoist_code: List[CNode] = [L.ForRange(B_indices[0], 0, blockdims[0], body=hoist)] if hoist else [] + hoist_code: List[LNode] = [L.ForRange(B_indices[0], 0, blockdims[0], body=hoist)] if hoist else [] - body: List[CNode] = [] + body: List[LNode] = [] for indices in keep: - sum = L.Sum(keep[indices]) - body.append(L.AssignAdd(A[indices], sum)) + body.append(L.AssignAdd(A[indices], L.Sum(keep[indices]))) for i in reversed(range(block_rank)): body = [L.ForRange(B_indices[i], 0, blockdims[i], body=body)] @@ -626,8 +587,6 @@ def fuse_loops(self, definitions): determine how many loops should fuse at a time. """ - L = self.backend.language - loops = collections.defaultdict(list) pre_loop = [] for access, definition in definitions.items(): diff --git a/ffcx/codegeneration/symbols.py b/ffcx/codegeneration/symbols.py index 7630a0d68..41d6dd20b 100644 --- a/ffcx/codegeneration/symbols.py +++ b/ffcx/codegeneration/symbols.py @@ -7,6 +7,7 @@ import logging import ufl +import ffcx.codegeneration.lnodes as L logger = logging.getLogger("ffcx") @@ -60,10 +61,8 @@ def format_mt_name(basename, mt): class FFCXBackendSymbols(object): """FFCx specific symbol definitions. Provides non-ufl symbols.""" - def __init__(self, language, coefficient_numbering, coefficient_offsets, + def __init__(self, coefficient_numbering, coefficient_offsets, original_constant_offsets): - self.L = language - self.S = self.L.Symbol self.coefficient_numbering = coefficient_numbering self.coefficient_offsets = coefficient_offsets @@ -71,71 +70,71 @@ def __init__(self, language, coefficient_numbering, coefficient_offsets, def element_tensor(self): """Symbol for the element tensor itself.""" - return self.S("A") + return L.Symbol("A") def entity(self, entitytype, restriction): """Entity index for lookup in element tables.""" if entitytype == "cell": # Always 0 for cells (even with restriction) - return self.L.LiteralInt(0) + return L.LiteralInt(0) elif entitytype == "facet": postfix = "[0]" if restriction == "-": postfix = "[1]" - return self.S("entity_local_index" + postfix) + return L.Symbol("entity_local_index" + postfix, dtype=L.DataType.INT) elif entitytype == "vertex": - return self.S("entity_local_index[0]") + return L.Symbol("entity_local_index[0]", dtype=L.DataType.INT) else: logging.exception(f"Unknown entitytype {entitytype}") def argument_loop_index(self, iarg): """Loop index for argument #iarg.""" indices = ["i", "j", "k", "l"] - return self.S(indices[iarg]) + return L.Symbol(indices[iarg], dtype=L.DataType.INT) def coefficient_dof_sum_index(self): """Index for loops over coefficient dofs, assumed to never be used in two nested loops.""" - return self.S("ic") + return L.Symbol("ic", dtype=L.DataType.INT) def quadrature_loop_index(self): """Reusing a single index name for all quadrature loops, assumed not to be nested.""" - return self.S("iq") + return L.Symbol("iq", dtype=L.DataType.INT) def quadrature_permutation(self, index): """Quadrature permutation, as input to the function.""" - return self.S("quadrature_permutation")[index] + return L.Symbol("quadrature_permutation", dtype=L.DataType.INT)[index] def custom_weights_table(self): """Table for chunk of custom quadrature weights (including cell measure scaling).""" - return self.S("weights_chunk") + return L.Symbol("weights_chunk", dtype=L.DataType.REAL) def custom_points_table(self): """Table for chunk of custom quadrature points (physical coordinates).""" - return self.S("points_chunk") + return L.Symbol("points_chunk", dtype=L.DataType.REAL) def weights_table(self, quadrature_rule): """Table of quadrature weights.""" - return self.S(f"weights_{quadrature_rule.id()}") + return L.Symbol(f"weights_{quadrature_rule.id()}", dtype=L.DataType.REAL) def points_table(self, quadrature_rule): """Table of quadrature points (points on the reference integration entity).""" - return self.S(f"points_{quadrature_rule.id()}") + return L.Symbol(f"points_{quadrature_rule.id()}", dtype=L.DataType.REAL) def x_component(self, mt): """Physical coordinate component.""" - return self.S(format_mt_name("x", mt)) + return L.Symbol(format_mt_name("x", mt), dtype=L.DataType.REAL) def J_component(self, mt): """Jacobian component.""" # FIXME: Add domain number! - return self.S(format_mt_name("J", mt)) + return L.Symbol(format_mt_name("J", mt), dtype=L.DataType.REAL) def domain_dof_access(self, dof, component, gdim, num_scalar_dofs, restriction): # FIXME: Add domain number or offset! offset = 0 if restriction == "-": offset = num_scalar_dofs * 3 - vc = self.S("coordinate_dofs") + vc = L.Symbol("coordinate_dofs", dtype=L.DataType.REAL) return vc[3 * dof + component + offset] def domain_dofs_access(self, gdim, num_scalar_dofs, restriction): @@ -147,14 +146,14 @@ def domain_dofs_access(self, gdim, num_scalar_dofs, restriction): def coefficient_dof_access(self, coefficient, dof_index): offset = self.coefficient_offsets[coefficient] - w = self.S("w") + w = L.Symbol("w", dtype=L.DataType.SCALAR) return w[offset + dof_index] def coefficient_dof_access_blocked(self, coefficient: ufl.Coefficient, index, block_size, dof_offset): coeff_offset = self.coefficient_offsets[coefficient] - w = self.S("w") - _w = self.S(f"_w_{coeff_offset}_{dof_offset}") + w = L.Symbol("w", dtype=L.DataType.SCALAR) + _w = L.Symbol(f"_w_{coeff_offset}_{dof_offset}", dtype=L.DataType.SCALAR) unit_stride_access = _w[index] original_access = w[coeff_offset + index * block_size + dof_offset] return unit_stride_access, original_access @@ -162,17 +161,14 @@ def coefficient_dof_access_blocked(self, coefficient: ufl.Coefficient, index, def coefficient_value(self, mt): """Symbol for variable holding value or derivative component of coefficient.""" c = self.coefficient_numbering[mt.terminal] - return self.S(format_mt_name("w%d" % (c, ), mt)) + return L.Symbol(format_mt_name("w%d" % (c, ), mt), dtype=L.DataType.SCALAR) def constant_index_access(self, constant, index): offset = self.original_constant_offsets[constant] - c = self.S("c") + c = L.Symbol("c", dtype=L.DataType.SCALAR) return c[offset + index] - def named_table(self, name): - return self.S(name) - def element_table(self, tabledata, entitytype, restriction): entity = self.entity(entitytype, restriction) @@ -194,4 +190,4 @@ def element_table(self, tabledata, entitytype, restriction): qp = 0 # Return direct access to element table - return self.named_table(tabledata.name)[qp][entity][iq] + return L.Symbol(tabledata.name, dtype=L.DataType.REAL)[qp][entity][iq] diff --git a/ffcx/codegeneration/utils.py b/ffcx/codegeneration/utils.py new file mode 100644 index 000000000..06497c172 --- /dev/null +++ b/ffcx/codegeneration/utils.py @@ -0,0 +1,34 @@ +# Copyright (C) 2020-2023 Michal Habera and Chris Richardson +# +# This file is part of FFCx.(https://www.fenicsproject.org) +# +# SPDX-License-Identifier: LGPL-3.0-or-later + +def cdtype_to_numpy(cdtype: str): + """Map a C data type string NumPy datatype string.""" + if cdtype == "double": + return "float64" + elif cdtype == "double _Complex": + return "complex128" + elif cdtype == "float": + return "float32" + elif cdtype == "float _Complex": + return "complex64" + elif cdtype == "long double": + return "longdouble" + else: + raise RuntimeError(f"Unknown NumPy type for: {cdtype}") + + +def scalar_to_value_type(scalar_type: str) -> str: + """The C value type associated with a C scalar type. + + Args: + scalar_type: A C type. + + Returns: + The value type associated with ``scalar_type``. E.g., if + ``scalar_type`` is ``float _Complex`` the return value is 'float'. + + """ + return scalar_type.replace(' _Complex', '') From b85eb4f07cb49e6c71a8169f195f48b87ada4393 Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Fri, 25 Aug 2023 15:11:27 +0100 Subject: [PATCH 04/13] Work mostly on Expression --- ffcx/codegeneration/access.py | 8 --- ffcx/codegeneration/definitions.py | 4 +- ffcx/codegeneration/expression_generator.py | 48 ++++++----------- ffcx/codegeneration/geometry.py | 60 ++++++++++++--------- ffcx/codegeneration/integral_generator.py | 2 +- 5 files changed, 52 insertions(+), 70 deletions(-) diff --git a/ffcx/codegeneration/access.py b/ffcx/codegeneration/access.py index b88de608b..2b1658d58 100644 --- a/ffcx/codegeneration/access.py +++ b/ffcx/codegeneration/access.py @@ -138,7 +138,6 @@ def cell_coordinate(self, e, mt, tabledata, num_points): raise RuntimeError("Expecting reference cell coordinate to be symbolically rewritten.") def facet_coordinate(self, e, mt, tabledata, num_points): - L = self.language if mt.global_derivatives: raise RuntimeError("Not expecting derivatives of FacetCoordinate.") if mt.local_derivatives: @@ -185,7 +184,6 @@ def reference_cell_volume(self, e, mt, tabledata, access): raise RuntimeError(f"Unhandled cell types {cellname}.") def reference_facet_volume(self, e, mt, tabledata, access): - L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"): return L.Symbol(f"{cellname}_reference_facet_volume", dtype=L.DataType.REAL) @@ -193,7 +191,6 @@ def reference_facet_volume(self, e, mt, tabledata, access): raise RuntimeError(f"Unhandled cell types {cellname}.") def reference_normal(self, e, mt, tabledata, access): - L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"): table = L.Symbol(f"{cellname}_reference_facet_normals", dtype=L.DataType.REAL) @@ -203,7 +200,6 @@ def reference_normal(self, e, mt, tabledata, access): raise RuntimeError(f"Unhandled cell types {cellname}.") def cell_facet_jacobian(self, e, mt, tabledata, num_points): - L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"): table = L.Symbol(f"{cellname}_reference_facet_jacobian", dtype=L.DataType.REAL) @@ -215,7 +211,6 @@ def cell_facet_jacobian(self, e, mt, tabledata, num_points): raise RuntimeError(f"Unhandled cell types {cellname}.") def reference_cell_edge_vectors(self, e, mt, tabledata, num_points): - L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"): table = L.Symbol(f"{cellname}_reference_edge_vectors", dtype=L.DataType.REAL) @@ -226,7 +221,6 @@ def reference_cell_edge_vectors(self, e, mt, tabledata, num_points): raise RuntimeError(f"Unhandled cell types {cellname}.") def reference_facet_edge_vectors(self, e, mt, tabledata, num_points): - L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("tetrahedron", "hexahedron"): table = L.Symbol(f"{cellname}_reference_edge_vectors", dtype=L.DataType.REAL) @@ -240,7 +234,6 @@ def reference_facet_edge_vectors(self, e, mt, tabledata, num_points): raise RuntimeError(f"Unhandled cell types {cellname}.") def facet_orientation(self, e, mt, tabledata, num_points): - L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname not in ("interval", "triangle", "tetrahedron"): raise RuntimeError(f"Unhandled cell types {cellname}.") @@ -312,7 +305,6 @@ def cell_edge_vectors(self, e, mt, tabledata, num_points): ) def facet_edge_vectors(self, e, mt, tabledata, num_points): - L = self.language # Get properties of domain domain = ufl.domain.extract_unique_domain(mt.terminal) diff --git a/ffcx/codegeneration/definitions.py b/ffcx/codegeneration/definitions.py index 07390fb1c..3fbbacbda 100644 --- a/ffcx/codegeneration/definitions.py +++ b/ffcx/codegeneration/definitions.py @@ -96,8 +96,8 @@ def coefficient(self, t, mt, tabledata, quadrature_rule, access): # If a map is necessary from stride 1 to bs, the code must be added before the quadrature loop. if dof_access_map: - pre_code += [L.ArrayDecl(self.options["scalar_type"], dof_access.array, num_dofs)] - pre_body = L.Assign(dof_access, dof_access_map) + pre_code += [L.ArrayDecl(dof_access.array, sizes=num_dofs)] + pre_body = [L.Assign(dof_access, dof_access_map)] pre_code += [L.ForRange(ic, 0, num_dofs, pre_body)] else: dof_access = self.symbols.coefficient_dof_access(mt.terminal, ic * bs + begin) diff --git a/ffcx/codegeneration/expression_generator.py b/ffcx/codegeneration/expression_generator.py index 28ee90e29..54af21215 100644 --- a/ffcx/codegeneration/expression_generator.py +++ b/ffcx/codegeneration/expression_generator.py @@ -15,7 +15,6 @@ import ffcx.codegeneration.lnodes as L from ffcx.codegeneration.lnodes import LNode from ffcx.ir.representation import ExpressionIR -from ffcx.naming import scalar_to_value_type logger = logging.getLogger("ffcx") @@ -28,6 +27,7 @@ def __init__(self, ir: ExpressionIR, backend: FFCXBackend): self.ir = ir self.backend = backend + self.ufl_to_language = L.UFL2LNodes() self.scope: Dict[Any, LNode] = {} self._ufl_names: Set[Any] = set() self.symbol_counters: DefaultDict[Any, int] = collections.defaultdict(int) @@ -35,15 +35,10 @@ def __init__(self, ir: ExpressionIR, backend: FFCXBackend): self.quadrature_rule = list(self.ir.integrand.keys())[0] def generate(self): - L = self.backend.language - parts = [] - scalar_type = self.backend.access.options["scalar_type"] - value_type = scalar_to_value_type(scalar_type) - - parts += self.generate_element_tables(value_type) + parts += self.generate_element_tables() # Generate the tables of geometry data that are needed - parts += self.generate_geometry_tables(value_type) + parts += self.generate_geometry_tables() parts += self.generate_piecewise_partition() all_preparts = [] @@ -78,11 +73,11 @@ def generate_geometry_tables(self): parts = [] for i, cell_list in cells.items(): for c in cell_list: - parts.append(geometry.write_table(L, ufl_geometry[i], c)) + parts.append(geometry.write_table(ufl_geometry[i], c)) return parts - def generate_element_tables(self, float_type: str): + def generate_element_tables(self): """Generate tables of FE basis evaluated at specified points.""" parts = [] @@ -91,7 +86,8 @@ def generate_element_tables(self, float_type: str): for name in table_names: table = tables[name] - decl = L.ArrayDecl(name, table) + symbol = L.Symbol(name, dtype=L.DataType.REAL) + decl = L.ArrayDecl(symbol, sizes=table.shape, values=table, const=True) parts += [decl] # Add leading comment if there are any tables @@ -107,8 +103,6 @@ def generate_quadrature_loop(self): In the context of expressions quadrature loop is not accumulated. """ - L = self.backend.language - # Generate varying partition body = self.generate_varying_partition() body = L.commented_code_list( @@ -133,12 +127,10 @@ def generate_quadrature_loop(self): def generate_varying_partition(self): """Generate factors of blocks which are not cellwise constant.""" - L = self.backend.language - # Get annotated graph of factorisation F = self.ir.integrand[self.quadrature_rule]["factorization"] - arraysymbol = L.Symbol(f"sv_{self.quadrature_rule.id()}") + arraysymbol = L.Symbol(f"sv_{self.quadrature_rule.id()}", dtype=L.DataType.SCALAR) parts = self.generate_partition(arraysymbol, F, "varying") parts = L.commented_code_list( parts, f"Unstructured varying computations for quadrature rule {self.quadrature_rule.id()}") @@ -146,12 +138,10 @@ def generate_varying_partition(self): def generate_piecewise_partition(self): """Generate factors of blocks which are constant (i.e. do not depend on quadrature points).""" - L = self.backend.language - # Get annotated graph of factorisation F = self.ir.integrand[self.quadrature_rule]["factorization"] - arraysymbol = L.Symbol("sp") + arraysymbol = L.Symbol("sp", dtype=L.DataType.SCALAR) parts = self.generate_partition(arraysymbol, F, "piecewise") parts = L.commented_code_list(parts, "Unstructured piecewise computations") return parts @@ -183,8 +173,6 @@ def generate_dofblock_partition(self): def generate_block_parts(self, blockmap, blockdata): """Generate and return code parts for a given block.""" - L = self.backend.language - # The parts to return preparts = [] quadparts = [] @@ -283,8 +271,6 @@ def get_arg_factors(self, blockdata, block_rank, indices): Indices used to index element tables """ - L = self.backend.language - arg_factors = [] for i in range(block_rank): mad = blockdata.ma_data[i] @@ -304,21 +290,18 @@ def get_arg_factors(self, blockdata, block_rank, indices): def new_temp_symbol(self, basename): """Create a new code symbol named basename + running counter.""" - L = self.backend.language name = "%s%d" % (basename, self.symbol_counters[basename]) self.symbol_counters[basename] += 1 - return L.Symbol(name) + return L.Symbol(name, dtype=L.DataType.SCALAR) def get_var(self, v): if v._ufl_is_literal_: - return self.backend.ufl_to_language.get(v) + return self.ufl_to_language.get(v) f = self.scope.get(v) return f def generate_partition(self, symbol, F, mode): """Generate computations of factors of blocks.""" - L = self.backend.language - definitions = [] pre_definitions = dict() intermediates = [] @@ -332,7 +315,7 @@ def generate_partition(self, symbol, F, mode): mt = attr.get('mt') if v._ufl_is_literal_: - vaccess = self.backend.ufl_to_language.get(v) + vaccess = self.ufl_to_language.get(v) elif mt is not None: # All finite element based terminals have table data, as well # as some, but not all, of the symbolic geometric terminals @@ -361,7 +344,7 @@ def generate_partition(self, symbol, F, mode): # Mapping UFL operator to target language self._ufl_names.add(v._ufl_handler_name_) - vexpr = self.backend.ufl_to_language.get(v, *vops) + vexpr = self.ufl_to_language.get(v, *vops) # Create a new intermediate for each subexpression # except boolean conditions and its childs @@ -387,7 +370,7 @@ def generate_partition(self, symbol, F, mode): intermediates.append(L.Assign(vaccess, vexpr)) else: scalar_type = self.backend.access.options["scalar_type"] - vaccess = L.Symbol("%s_%d" % (symbol.name, j)) + vaccess = L.Symbol("%s_%d" % (symbol.name, j), dtype=L.DataType.SCALAR) intermediates.append(L.VariableDecl(f"const {scalar_type}", vaccess, vexpr)) # Store access node for future reference @@ -405,7 +388,6 @@ def generate_partition(self, symbol, F, mode): if intermediates: if use_symbol_array: - scalar_type = self.backend.access.options["scalar_type"] - parts += [L.ArrayDecl(scalar_type, symbol, len(intermediates))] + parts += [L.ArrayDecl(symbol, sizes=len(intermediates))] parts += intermediates return parts diff --git a/ffcx/codegeneration/geometry.py b/ffcx/codegeneration/geometry.py index 271a438fa..bc68375bb 100644 --- a/ffcx/codegeneration/geometry.py +++ b/ffcx/codegeneration/geometry.py @@ -5,31 +5,31 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import numpy as np - +import ffcx.codegeneration.lnodes as L import basix -def write_table(L, tablename, cellname): +def write_table(tablename, cellname): if tablename == "facet_edge_vertices": - return facet_edge_vertices(L, tablename, cellname) + return facet_edge_vertices(tablename, cellname) if tablename == "reference_facet_jacobian": - return reference_facet_jacobian(L, tablename, cellname) + return reference_facet_jacobian(tablename, cellname) if tablename == "reference_cell_volume": - return reference_cell_volume(L, tablename, cellname) + return reference_cell_volume(tablename, cellname) if tablename == "reference_facet_volume": - return reference_facet_volume(L, tablename, cellname) + return reference_facet_volume(tablename, cellname) if tablename == "reference_edge_vectors": - return reference_edge_vectors(L, tablename, cellname) + return reference_edge_vectors(tablename, cellname) if tablename == "facet_reference_edge_vectors": - return facet_reference_edge_vectors(L, tablename, cellname) + return facet_reference_edge_vectors(tablename, cellname) if tablename == "reference_facet_normals": - return reference_facet_normals(L, tablename, cellname) + return reference_facet_normals(tablename, cellname) if tablename == "facet_orientation": - return facet_orientation(L, tablename, cellname) + return facet_orientation(tablename, cellname) raise ValueError(f"Unknown geometry table name: {tablename}") -def facet_edge_vertices(L, tablename, cellname): +def facet_edge_vertices(tablename, cellname): celltype = getattr(basix.CellType, cellname) topology = basix.topology(celltype) triangle_edges = basix.topology(basix.CellType.triangle)[1] @@ -48,40 +48,45 @@ def facet_edge_vertices(L, tablename, cellname): raise ValueError("Only triangular and quadrilateral faces supported.") out = np.array(edge_vertices, dtype=int) - return L.ArrayDecl("static const unsigned int", f"{cellname}_{tablename}", out.shape, out) + symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.INT) + return L.ArrayDecl(symbol, values=out, const=True) -def reference_facet_jacobian(L, tablename, cellname, type: str): +def reference_facet_jacobian(tablename, cellname): celltype = getattr(basix.CellType, cellname) out = basix.cell.facet_jacobians(celltype) - return L.ArrayDecl(f"static const {type}", f"{cellname}_{tablename}", out.shape, out) + symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.REAL) + return L.ArrayDecl(symbol, values=out, const=True) -def reference_cell_volume(L, tablename, cellname, type: str): +def reference_cell_volume(tablename, cellname): celltype = getattr(basix.CellType, cellname) out = basix.cell.volume(celltype) - return L.VariableDecl(f"static const {type}", f"{cellname}_{tablename}", out) + symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.REAL) + return L.VariableDecl(symbol, out) -def reference_facet_volume(L, tablename, cellname, type: str): +def reference_facet_volume(tablename, cellname): celltype = getattr(basix.CellType, cellname) volumes = basix.cell.facet_reference_volumes(celltype) for i in volumes[1:]: if not np.isclose(i, volumes[0]): raise ValueError("Reference facet volume not supported for this cell type.") - return L.VariableDecl(f"static const {type}", f"{cellname}_{tablename}", volumes[0]) + symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.REAL) + return L.VariableDecl(symbol, volumes[0]) -def reference_edge_vectors(L, tablename, cellname, type: str): +def reference_edge_vectors(tablename, cellname): celltype = getattr(basix.CellType, cellname) topology = basix.topology(celltype) geometry = basix.geometry(celltype) edge_vectors = [geometry[j] - geometry[i] for i, j in topology[1]] out = np.array(edge_vectors[cellname]) - return L.ArrayDecl(f"static const {type}", f"{cellname}_{tablename}", out.shape, out) + symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.REAL) + return L.ArrayDecl(symbol, values=out, const=True) -def facet_reference_edge_vectors(L, tablename, cellname, type: str): +def facet_reference_edge_vectors(tablename, cellname): celltype = getattr(basix.CellType, cellname) topology = basix.topology(celltype) geometry = basix.geometry(celltype) @@ -101,16 +106,19 @@ def facet_reference_edge_vectors(L, tablename, cellname, type: str): raise ValueError("Only triangular and quadrilateral faces supported.") out = np.array(edge_vectors) - return L.ArrayDecl(f"static const {type}", f"{cellname}_{tablename}", out.shape, out) + symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.REAL) + return L.ArrayDecl(symbol, values=out, const=True) -def reference_facet_normals(L, tablename, cellname, type: str): +def reference_facet_normals(tablename, cellname): celltype = getattr(basix.CellType, cellname) out = basix.cell.facet_outward_normals(celltype) - return L.ArrayDecl(f"static const {type}", f"{cellname}_{tablename}", out.shape, out) + symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.REAL) + return L.ArrayDecl(symbol, values=out, const=True) -def facet_orientation(L, tablename, cellname, type: str): +def facet_orientation(tablename, cellname): celltype = getattr(basix.CellType, cellname) out = basix.cell.facet_orientations(celltype) - return L.ArrayDecl(f"static const {type}", f"{cellname}_{tablename}", len(out), out) + symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.REAL) + return L.ArrayDecl(symbol, values=out, const=True) diff --git a/ffcx/codegeneration/integral_generator.py b/ffcx/codegeneration/integral_generator.py index c58111737..452437f7e 100644 --- a/ffcx/codegeneration/integral_generator.py +++ b/ffcx/codegeneration/integral_generator.py @@ -193,7 +193,7 @@ def generate_geometry_tables(self): parts = [] for i, cell_list in cells.items(): for c in cell_list: - parts.append(geometry.write_table(L, ufl_geometry[i], c)) + parts.append(geometry.write_table(ufl_geometry[i], c)) return parts From d7bcbafe1d5ccbdc1d53146ef107085147058373 Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Fri, 25 Aug 2023 15:25:44 +0100 Subject: [PATCH 05/13] Remove flop count test --- test/test_flops.py | 40 ---------------------------------------- 1 file changed, 40 deletions(-) delete mode 100644 test/test_flops.py diff --git a/test/test_flops.py b/test/test_flops.py deleted file mode 100644 index 9ac99ad77..000000000 --- a/test/test_flops.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (C) 2021 Igor A. Baratta -# -# This file is part of FFCx. (https://www.fenicsproject.org) -# -# SPDX-License-Identifier: LGPL-3.0-or-later - - -import ufl -import basix.ufl -from ffcx.codegeneration.flop_count import count_flops - - -def create_form(degree): - mesh = ufl.Mesh(basix.ufl.element("Lagrange", "triangle", 1, rank=1)) - element = basix.ufl.element("Lagrange", "triangle", degree) - V = ufl.FunctionSpace(mesh, element) - - u = ufl.TrialFunction(V) - v = ufl.TestFunction(V) - - return ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx + ufl.inner(u, v) * ufl.ds - - -def test_flops(): - k1, k2 = 2, 4 - a1 = create_form(k1) - a2 = create_form(k2) - - dofs1 = (k1 + 1.) * (k1 + 2.) / 2. - dofs2 = (k2 + 1.) * (k2 + 2.) / 2. - - flops_1 = count_flops(a1) - assert len(flops_1) == 2 - - flops_2 = count_flops(a2) - assert len(flops_2) == 2 - - r = sum(flops_2, 0.) / sum(flops_1, 0.) - - assert r > (dofs2**2 / dofs1**2) From e8d3b23338c23976cc8425875117ae27a5b93194 Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Tue, 29 Aug 2023 12:19:30 +0100 Subject: [PATCH 06/13] Add demo to test complex literal --- demo/ComplexPoisson.py | 39 +++++++++++++++++++++++++++++++++++++++ demo/test_demos.py | 6 +++++- 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 demo/ComplexPoisson.py diff --git a/demo/ComplexPoisson.py b/demo/ComplexPoisson.py new file mode 100644 index 000000000..731ec9c09 --- /dev/null +++ b/demo/ComplexPoisson.py @@ -0,0 +1,39 @@ +# Copyright (C) 2023 Chris Richardson +# +# This file is part of FFCx. +# +# FFCx is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# FFCx is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with FFCx. If not, see . +# +# The bilinear form a(u, v) and linear form L(v) for +# Poisson's equation using bilinear elements on bilinear mesh geometry. +import basix.ufl +from ufl import (Coefficient, FunctionSpace, Mesh, TestFunction, TrialFunction, + dx, grad, inner) + +coords = basix.ufl.element("P", "triangle", 2, rank=1) +mesh = Mesh(coords) +dx = dx(mesh) + +element = basix.ufl.element("P", mesh.ufl_cell().cellname(), 2) +space = FunctionSpace(mesh, element) + +u = TrialFunction(space) +v = TestFunction(space) +f = Coefficient(space) + +# Test literal complex number in form +k = 3.213 + 1.023j + +a = k * inner(grad(u), grad(v)) * dx +L = inner(k * f, v) * dx diff --git a/demo/test_demos.py b/demo/test_demos.py index d28e394e4..8cf199619 100644 --- a/demo/test_demos.py +++ b/demo/test_demos.py @@ -22,8 +22,12 @@ def test_demo(file): # Skip demos that use elements not yet implemented in Basix pytest.skip() + opts = "" + if "Complex" in file: + opts = '--scalar_type "double _Complex"' + extra_flags = "-Wunused-variable -Werror -fPIC " - assert os.system(f"cd {demo_dir} && ffcx {file}.py") == 0 + assert os.system(f"cd {demo_dir} && ffcx {opts} {file}.py") == 0 assert os.system(f"cd {demo_dir} && " "CPATH=../ffcx/codegeneration/ " f"gcc -I/usr/include/python{sys.version_info.major}.{sys.version_info.minor} {extra_flags}" From cdf41a5811190e5ed507ac063675a2b066a26684 Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Tue, 29 Aug 2023 12:48:05 +0100 Subject: [PATCH 07/13] Formatting tweaks for complex numbers --- ffcx/codegeneration/C/c_implementation.py | 39 +++++++++++++++-------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/ffcx/codegeneration/C/c_implementation.py b/ffcx/codegeneration/C/c_implementation.py index f735673c8..36935e19b 100644 --- a/ffcx/codegeneration/C/c_implementation.py +++ b/ffcx/codegeneration/C/c_implementation.py @@ -6,7 +6,8 @@ import warnings import ffcx.codegeneration.lnodes as L -from ffcx.codegeneration.utils import scalar_to_value_type +from ffcx.codegeneration.utils import scalar_to_value_type, cdtype_to_numpy +import numpy as np math_table = { "double": { @@ -137,20 +138,29 @@ } -def build_initializer_lists(values): - arr = "{" - if len(values.shape) == 1: - arr += ", ".join(str(v) for v in values) - elif len(values.shape) > 1: - arr += ",\n ".join(build_initializer_lists(v) for v in values) - arr += "}" - return arr - - class CFormatter(object): - def __init__(self, scalar) -> None: + def __init__(self, scalar, precision=None) -> None: self.scalar_type = scalar self.real_type = scalar_to_value_type(scalar) + if precision is None: + np_type = cdtype_to_numpy(self.real_type) + self.precision = np.finfo(np_type).precision + 1 + + def _format_float(self, x): + prec = self.precision + if isinstance(x, complex): + return "({:.{prec}}+I*{:.{prec}})".format(x.real, x.imag, prec=prec) + else: + return "{:.{prec}}".format(x, prec=prec) + + def _build_initializer_lists(self, values): + arr = "{" + if len(values.shape) == 1: + arr += ", ".join(self._format_float(v) for v in values) + elif len(values.shape) > 1: + arr += ",\n ".join(self._build_initializer_lists(v) for v in values) + arr += "}" + return arr def format_statement_list(self, slist) -> str: return "".join(self.c_format(s) for s in slist.statements) @@ -175,7 +185,7 @@ def format_array_decl(self, arr) -> str: assert arr.const is False return f"{typename} {symbol}{dims};\n" - vals = build_initializer_lists(arr.values) + vals = self._build_initializer_lists(arr.values) cstr = "static const " if arr.const else "" return f"{cstr}{typename} {symbol}{dims} = {vals};\n" @@ -229,7 +239,8 @@ def format_not(self, val) -> str: return f"{val.op}({arg})" def format_literal_float(self, val) -> str: - return f"{val.value}" + value = self._format_float(val.value) + return f"{value}" def format_literal_int(self, val) -> str: return f"{val.value}" From 9afe210967fa8cd94f97aaae4aed83f103afdf17 Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Tue, 29 Aug 2023 14:37:36 +0100 Subject: [PATCH 08/13] Fixes for obscure functionality (which is probably broken in main) --- demo/CellGeometry.py | 8 +++++++- ffcx/codegeneration/C/c_implementation.py | 14 ++++++++++---- ffcx/codegeneration/access.py | 2 +- ffcx/codegeneration/geometry.py | 2 +- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/demo/CellGeometry.py b/demo/CellGeometry.py index 57cd9e88f..8eea59f0f 100644 --- a/demo/CellGeometry.py +++ b/demo/CellGeometry.py @@ -3,7 +3,8 @@ # A functional M involving a bunch of cell geometry quantities. import basix.ufl from ufl import (CellVolume, Circumradius, Coefficient, FacetArea, FacetNormal, - SpatialCoordinate, ds, dx, tetrahedron) + SpatialCoordinate, ds, dx, tetrahedron, TrialFunction) +from ufl.geometry import FacetEdgeVectors cell = tetrahedron V = basix.ufl.element("P", cell.cellname(), 1) @@ -17,3 +18,8 @@ area = FacetArea(cell) M = u * (x[0] * vol * rad) * dx + u * (x[0] * vol * rad * area) * ds # + u*area*avg(n[0]*x[0]*vol*rad)*dS + +# Test some obscure functionality +fev = FacetEdgeVectors(cell) +v = TrialFunction(V) +L = fev[0, 0] * v * ds diff --git a/ffcx/codegeneration/C/c_implementation.py b/ffcx/codegeneration/C/c_implementation.py index 36935e19b..cd0e86244 100644 --- a/ffcx/codegeneration/C/c_implementation.py +++ b/ffcx/codegeneration/C/c_implementation.py @@ -145,18 +145,22 @@ def __init__(self, scalar, precision=None) -> None: if precision is None: np_type = cdtype_to_numpy(self.real_type) self.precision = np.finfo(np_type).precision + 1 + else: + assert isinstance(precision, int) + self.precision = precision - def _format_float(self, x): + def _format_number(self, x): prec = self.precision if isinstance(x, complex): return "({:.{prec}}+I*{:.{prec}})".format(x.real, x.imag, prec=prec) - else: + elif isinstance(x, float): return "{:.{prec}}".format(x, prec=prec) + return str(x) def _build_initializer_lists(self, values): arr = "{" if len(values.shape) == 1: - arr += ", ".join(self._format_float(v) for v in values) + arr += ", ".join(self._format_number(v) for v in values) elif len(values.shape) > 1: arr += ",\n ".join(self._build_initializer_lists(v) for v in values) arr += "}" @@ -176,6 +180,8 @@ def format_array_decl(self, arr) -> str: typename = self.scalar_type elif dtype == L.DataType.REAL: typename = self.real_type + elif dtype == L.DataType.INT: + typename = "int" else: raise ValueError(f"Invalid dtype: {dtype}") @@ -239,7 +245,7 @@ def format_not(self, val) -> str: return f"{val.op}({arg})" def format_literal_float(self, val) -> str: - value = self._format_float(val.value) + value = self._format_number(val.value) return f"{value}" def format_literal_int(self, val) -> str: diff --git a/ffcx/codegeneration/access.py b/ffcx/codegeneration/access.py index 2b1658d58..b047251ee 100644 --- a/ffcx/codegeneration/access.py +++ b/ffcx/codegeneration/access.py @@ -333,7 +333,7 @@ def facet_edge_vectors(self, e, mt, tabledata, num_points): # Get edge vertices facet = self.symbols.entity("facet", mt.restriction) facet_edge = mt.component[0] - facet_edge_vertices = L.Symbol(f"{cellname}_facet_edge_vertices") + facet_edge_vertices = L.Symbol(f"{cellname}_facet_edge_vertices", dtype=L.DataType.INT) vertex0 = facet_edge_vertices[facet][facet_edge][0] vertex1 = facet_edge_vertices[facet][facet_edge][1] diff --git a/ffcx/codegeneration/geometry.py b/ffcx/codegeneration/geometry.py index bc68375bb..b52f0f473 100644 --- a/ffcx/codegeneration/geometry.py +++ b/ffcx/codegeneration/geometry.py @@ -81,7 +81,7 @@ def reference_edge_vectors(tablename, cellname): topology = basix.topology(celltype) geometry = basix.geometry(celltype) edge_vectors = [geometry[j] - geometry[i] for i, j in topology[1]] - out = np.array(edge_vectors[cellname]) + out = np.array(edge_vectors) symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.REAL) return L.ArrayDecl(symbol, values=out, const=True) From b4683a5b1021ba6f27e434434ad9d0ee043e3f69 Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Tue, 29 Aug 2023 15:01:18 +0100 Subject: [PATCH 09/13] Correction to facet_edge_vertices --- ffcx/codegeneration/geometry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ffcx/codegeneration/geometry.py b/ffcx/codegeneration/geometry.py index b52f0f473..a5b87f215 100644 --- a/ffcx/codegeneration/geometry.py +++ b/ffcx/codegeneration/geometry.py @@ -41,9 +41,9 @@ def facet_edge_vertices(tablename, cellname): edge_vertices = [] for facet in topology[-2]: if len(facet) == 3: - edge_vertices += [[facet[i] for i in edge] for edge in triangle_edges] + edge_vertices += [[[facet[i] for i in edge] for edge in triangle_edges]] elif len(facet) == 4: - edge_vertices += [[facet[i] for i in edge] for edge in quadrilateral_edges] + edge_vertices += [[[facet[i] for i in edge] for edge in quadrilateral_edges]] else: raise ValueError("Only triangular and quadrilateral faces supported.") From 5198ba64ff90652ed0d5f4dd4bd895354c70a0ef Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Thu, 31 Aug 2023 17:32:33 +0100 Subject: [PATCH 10/13] Make precision work --- demo/MetaData.py | 2 ++ ffcx/analysis.py | 2 +- ffcx/codegeneration/C/c_implementation.py | 6 +++--- ffcx/codegeneration/C/integrals.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/demo/MetaData.py b/demo/MetaData.py index bc94b7465..3de96f592 100644 --- a/demo/MetaData.py +++ b/demo/MetaData.py @@ -32,3 +32,5 @@ + inner(c, c) * inner(grad(u), grad(v)) * dx(1, degree=4)\ + inner(c, c) * inner(grad(u), grad(v)) * dx(1, degree=2)\ + inner(grad(u), grad(v)) * dx(1, degree=-1) + +L = v * dx(0, metadata={"precision": 1}) diff --git a/ffcx/analysis.py b/ffcx/analysis.py index 51ea1610f..0e55a6e19 100644 --- a/ffcx/analysis.py +++ b/ffcx/analysis.py @@ -202,7 +202,7 @@ def _analyze_form(form: ufl.form.Form, options: typing.Dict) -> ufl.algorithms.f p = precisions.pop() elif len(precisions) == 0: # Default precision - p = np.finfo("double").precision + 1 # == 16 + p = None else: raise RuntimeError("Only one precision allowed within integrals grouped by subdomain.") diff --git a/ffcx/codegeneration/C/c_implementation.py b/ffcx/codegeneration/C/c_implementation.py index cd0e86244..f6d18bacd 100644 --- a/ffcx/codegeneration/C/c_implementation.py +++ b/ffcx/codegeneration/C/c_implementation.py @@ -150,11 +150,11 @@ def __init__(self, scalar, precision=None) -> None: self.precision = precision def _format_number(self, x): - prec = self.precision + p = self.precision if isinstance(x, complex): - return "({:.{prec}}+I*{:.{prec}})".format(x.real, x.imag, prec=prec) + return f"({x.real:.{p}}+I*{x.imag:.{p}})" elif isinstance(x, float): - return "{:.{prec}}".format(x, prec=prec) + return f"{x:.{p}}" return str(x) def _build_initializer_lists(self, values): diff --git a/ffcx/codegeneration/C/integrals.py b/ffcx/codegeneration/C/integrals.py index c6916591a..386e3ed40 100644 --- a/ffcx/codegeneration/C/integrals.py +++ b/ffcx/codegeneration/C/integrals.py @@ -36,7 +36,7 @@ def generator(ir, options): parts = ig.generate() # Format code as string - CF = CFormatter(options["scalar_type"]) + CF = CFormatter(options["scalar_type"], ir.precision) body = CF.c_format(parts) # Generate generic FFCx code snippets and add specific parts From 7a5a39642b868a9f987aa7a253a15cf43b63ebd5 Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Mon, 4 Sep 2023 09:49:31 +0100 Subject: [PATCH 11/13] Remove unused files --- ffcx/codegeneration/C/format_lines.py | 60 ----- ffcx/codegeneration/C/format_value.py | 59 ----- ffcx/codegeneration/C/precedence.py | 61 ----- ffcx/codegeneration/C/ufl_to_cnodes.py | 295 ------------------------- 4 files changed, 475 deletions(-) delete mode 100644 ffcx/codegeneration/C/format_lines.py delete mode 100644 ffcx/codegeneration/C/format_value.py delete mode 100644 ffcx/codegeneration/C/precedence.py delete mode 100644 ffcx/codegeneration/C/ufl_to_cnodes.py diff --git a/ffcx/codegeneration/C/format_lines.py b/ffcx/codegeneration/C/format_lines.py deleted file mode 100644 index 1fcfaee92..000000000 --- a/ffcx/codegeneration/C/format_lines.py +++ /dev/null @@ -1,60 +0,0 @@ -# This file is part of FFCx.(https://www.fenicsproject.org) -# -# SPDX-License-Identifier: LGPL-3.0-or-later -"""Tools for indentation-aware code string stitching. - -When formatting an AST into a string, it's better to collect lists of -snippets and then join them than adding the pieces continually, which -gives O(n^2) behaviour w.r.t. AST size n. - -""" - - -class Indented(object): - """Class to mark a collection of snippets for indentation. - - This way nested indentations can be handled by adding the prefix - spaces only once to each line instead of splitting and indenting - substrings repeatedly. - - """ - - # Try to keep memory overhead low: - __slots__ = ("body", ) - - def __init__(self, body): - # Body can be any valid snippet format - self.body = body - - -def iter_indented_lines(snippets, level=0): - """Iterate over indented string lines from a snippets data structure. - - The snippets object can be built recursively using the following types: - - - str: Split and yield as one line at a time indented to the appropriate level. - - - Indented: Yield the lines within this object indented by one level. - - - tuple,list: Yield lines from recursive application of this function to list items. - - """ - tabsize = 2 - indentation = ' ' * (tabsize * level) - if isinstance(snippets, str): - for line in snippets.split("\n"): - yield indentation + line - elif isinstance(snippets, Indented): - for line in iter_indented_lines(snippets.body, level + 1): - yield line - elif isinstance(snippets, (tuple, list)): - for part in snippets: - for line in iter_indented_lines(part, level): - yield line - else: - raise RuntimeError("Unexpected type %s:\n%s" % (type(snippets), str(snippets))) - - -def format_indented_lines(snippets, level=0): - """Format recursive sequences of indented lines as one string.""" - return "\n".join(iter_indented_lines(snippets, level)) diff --git a/ffcx/codegeneration/C/format_value.py b/ffcx/codegeneration/C/format_value.py deleted file mode 100644 index 0f057a3df..000000000 --- a/ffcx/codegeneration/C/format_value.py +++ /dev/null @@ -1,59 +0,0 @@ -# This file is part of FFCx.(https://www.fenicsproject.org) -# -# SPDX-License-Identifier: LGPL-3.0-or-later - -import numbers -import re - -_subs = ( - # Remove 0s after e+ or e- - (re.compile(r"e[\+]0*(.)"), r"e\1"), - (re.compile(r"e[\-]0*(.)"), r"e-\1"), -) - - -def format_float(x, precision=None): - """Format a float value according to given precision.""" - global _subs - - if precision: - if isinstance(x, complex): - s = "({:.{prec}}+I*{:.{prec}})".format(x.real, x.imag, prec=precision) - elif isinstance(x, float): - s = "{:.{prec}}".format(x, prec=precision) - else: - s = "{:.{prec}}".format(float(x), prec=precision) - else: - s = repr(float(x)) - for r, v in _subs: - s = r.sub(v, s) - return s - - -def format_int(x, precision=None): - return str(x) - - -def format_value(value, precision=None): - """Format a literal value as string. - - - float: Formatted according to current precision configuration. - - - int: Formatted as regular base 10 int literal. - - - str: Wrapped in "quotes". - - """ - if isinstance(value, numbers.Real): - return format_float(float(value), precision=precision) - elif isinstance(value, numbers.Integral): - return format_int(int(value)) - elif isinstance(value, str): - # FIXME: Is this ever used? - assert '"' not in value - return '"' + value + '"' - elif hasattr(value, "ce_format"): - return value.ce_format() - else: - raise RuntimeError("Unexpected type %s:\n%s" % (type(value), - str(value))) diff --git a/ffcx/codegeneration/C/precedence.py b/ffcx/codegeneration/C/precedence.py deleted file mode 100644 index a54c8c0f8..000000000 --- a/ffcx/codegeneration/C/precedence.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (C) 2011-2017 Martin Sandve Alnæs -# -# This file is part of FFCx.(https://www.fenicsproject.org) -# -# SPDX-License-Identifier: LGPL-3.0-or-later - - -class PRECEDENCE: - """An enum-like class for C operator precedence levels.""" - - HIGHEST = 0 - LITERAL = 0 - SYMBOL = 0 - - # SCOPE = 1 - - POST_INC = 2 - POST_DEC = 2 - CALL = 2 - SUBSCRIPT = 2 - # MEMBER = 2 - # PTR_MEMBER = 2 - - PRE_INC = 3 - PRE_DEC = 3 - NOT = 3 - BIT_NOT = 3 - POS = 3 - NEG = 3 - DEREFERENCE = 3 - ADDRESSOF = 3 - SIZEOF = 3 - - MUL = 4 - DIV = 4 - MOD = 4 - - ADD = 5 - SUB = 5 - - BITSHIFT = 6 - - LT = 7 - LE = 7 - GT = 7 - GE = 7 - - EQ = 8 - NE = 8 - - BITAND = 9 - - AND = 11 - OR = 12 - - CONDITIONAL = 13 - ASSIGN = 13 - - # COMMA = 14 - - LOWEST = 15 diff --git a/ffcx/codegeneration/C/ufl_to_cnodes.py b/ffcx/codegeneration/C/ufl_to_cnodes.py deleted file mode 100644 index 85d9faa94..000000000 --- a/ffcx/codegeneration/C/ufl_to_cnodes.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright (C) 2011-2017 Martin Sandve Alnæs -# -# This file is part of FFCx.(https://www.fenicsproject.org) -# -# SPDX-License-Identifier: LGPL-3.0-or-later -"""Tools for C/C++ expression formatting.""" - -import logging - -import ufl - -logger = logging.getLogger("ffcx") - -# Table of handled math functions for different scalar types - -math_table = {'double': {'sqrt': 'sqrt', - 'abs': 'fabs', - 'cos': 'cos', - 'sin': 'sin', - 'tan': 'tan', - 'acos': 'acos', - 'asin': 'asin', - 'atan': 'atan', - 'cosh': 'cosh', - 'sinh': 'sinh', - 'tanh': 'tanh', - 'acosh': 'acosh', - 'asinh': 'asinh', - 'atanh': 'atanh', - 'power': 'pow', - 'exp': 'exp', - 'ln': 'log', - 'erf': 'erf', - 'atan_2': 'atan2', - 'min_value': 'fmin', - 'max_value': 'fmax'}, - - 'float': {'sqrt': 'sqrtf', - 'abs': 'fabsf', - 'cos': 'cosf', - 'sin': 'sinf', - 'tan': 'tanf', - 'acos': 'acosf', - 'asin': 'asinf', - 'atan': 'atanf', - 'cosh': 'coshf', - 'sinh': 'sinhf', - 'tanh': 'tanhf', - 'acosh': 'acoshf', - 'asinh': 'asinhf', - 'atanh': 'atanhf', - 'power': 'powf', - 'exp': 'expf', - 'ln': 'logf', - 'erf': 'erff', - 'atan_2': 'atan2f', - 'min_value': 'fminf', - 'max_value': 'fmaxf'}, - - 'long double': {'sqrt': 'sqrtl', - 'abs': 'fabsl', - 'cos': 'cosl', - 'sin': 'sinl', - 'tan': 'tanl', - 'acos': 'acosl', - 'asin': 'asinl', - 'atan': 'atanl', - 'cosh': 'coshl', - 'sinh': 'sinhl', - 'tanh': 'tanhl', - 'acosh': 'acoshl', - 'asinh': 'asinhl', - 'atanh': 'atanhl', - 'power': 'powl', - 'exp': 'expl', - 'ln': 'logl', - 'erf': 'erfl', - 'atan_2': 'atan2l', - 'min_value': 'fminl', - 'max_value': 'fmaxl'}, - - 'double _Complex': {'sqrt': 'csqrt', - 'abs': 'cabs', - 'cos': 'ccos', - 'sin': 'csin', - 'tan': 'ctan', - 'acos': 'cacos', - 'asin': 'casin', - 'atan': 'catan', - 'cosh': 'ccosh', - 'sinh': 'csinh', - 'tanh': 'ctanh', - 'acosh': 'cacosh', - 'asinh': 'casinh', - 'atanh': 'catanh', - 'power': 'cpow', - 'exp': 'cexp', - 'ln': 'clog', - 'real': 'creal', - 'imag': 'cimag', - 'conj': 'conj', - 'max_value': 'fmax', - 'min_value': 'fmin'}, - - 'float _Complex': {'sqrt': 'csqrtf', - 'abs': 'cabsf', - 'cos': 'ccosf', - 'sin': 'csinf', - 'tan': 'ctanf', - 'acos': 'cacosf', - 'asin': 'casinf', - 'atan': 'catanf', - 'cosh': 'ccoshf', - 'sinh': 'csinhf', - 'tanh': 'ctanhf', - 'acosh': 'cacoshf', - 'asinh': 'casinhf', - 'atanh': 'catanhf', - 'power': 'cpowf', - 'exp': 'cexpf', - 'ln': 'clogf', - 'real': 'crealf', - 'imag': 'cimagf', - 'conj': 'conjf', - 'max_value': 'fmaxf', - 'min_value': 'fminf'}} - - -class UFL2CNodesTranslatorCpp(object): - """UFL to CNodes translator class.""" - - def __init__(self, language, scalar_type="double"): - self.L = language - self.force_floats = False - self.enable_strength_reduction = False - self.scalar_type = scalar_type - - # Lookup table for handler to call when the "get" method (below) is - # called, depending on the first argument type. - self.call_lookup = {ufl.constantvalue.IntValue: self.int_value, - ufl.constantvalue.FloatValue: self.float_value, - ufl.constantvalue.ComplexValue: self.complex_value, - ufl.constantvalue.Zero: self.zero, - ufl.algebra.Product: self.product, - ufl.algebra.Sum: self.sum, - ufl.algebra.Division: self.division, - ufl.algebra.Abs: self._cmath, - ufl.algebra.Power: self._cmath, - ufl.algebra.Real: self._cmath, - ufl.algebra.Imag: self._cmath, - ufl.algebra.Conj: self._cmath, - ufl.classes.GT: self.gt, - ufl.classes.GE: self.ge, - ufl.classes.EQ: self.eq, - ufl.classes.NE: self.ne, - ufl.classes.LT: self.lt, - ufl.classes.LE: self.le, - ufl.classes.AndCondition: self.and_condition, - ufl.classes.OrCondition: self.or_condition, - ufl.classes.NotCondition: self.not_condition, - ufl.classes.Conditional: self.conditional, - ufl.classes.MinValue: self._cmath, - ufl.classes.MaxValue: self._cmath, - ufl.mathfunctions.Sqrt: self._cmath, - ufl.mathfunctions.Ln: self._cmath, - ufl.mathfunctions.Exp: self._cmath, - ufl.mathfunctions.Cos: self._cmath, - ufl.mathfunctions.Sin: self._cmath, - ufl.mathfunctions.Tan: self._cmath, - ufl.mathfunctions.Cosh: self._cmath, - ufl.mathfunctions.Sinh: self._cmath, - ufl.mathfunctions.Tanh: self._cmath, - ufl.mathfunctions.Acos: self._cmath, - ufl.mathfunctions.Asin: self._cmath, - ufl.mathfunctions.Atan: self._cmath, - ufl.mathfunctions.Erf: self._cmath, - ufl.mathfunctions.Atan2: self._cmath, - ufl.mathfunctions.MathFunction: self.math_function, - ufl.mathfunctions.BesselJ: self.bessel_j, - ufl.mathfunctions.BesselY: self.bessel_y} - - def get(self, o, *args): - # Call appropriate handler, depending on the type of o - otype = type(o) - if otype in self.call_lookup: - return self.call_lookup[otype](o, *args) - else: - raise RuntimeError(f"Missing C formatting rule for expr type {otype}.") - - def expr(self, o, *args): - """Raise generic fallback with error message for missing rules.""" - raise RuntimeError(f"Missing C formatting rule for expr type {o._ufl_class_}.") - - # === Formatting rules for scalar literals === - - def zero(self, o): - return self.L.LiteralFloat(0.0) - - def float_value(self, o): - return self.L.LiteralFloat(float(o)) - - def int_value(self, o): - if self.force_floats: - return self.float_value(o) - return self.L.LiteralInt(int(o)) - - def complex_value(self, o): - return self.L.LiteralFloat(o.value()) - - # === Formatting rules for arithmetic operators === - - def sum(self, o, a, b): - return self.L.Add(a, b) - - def product(self, o, a, b): - return self.L.Mul(a, b) - - def division(self, o, a, b): - if self.enable_strength_reduction: - return self.L.Mul(a, self.L.Div(1.0, b)) - else: - return self.L.Div(a, b) - - # === Formatting rules for conditional expressions === - - def conditional(self, o, c, t, f): - return self.L.Conditional(c, t, f) - - def eq(self, o, a, b): - return self.L.EQ(a, b) - - def ne(self, o, a, b): - return self.L.NE(a, b) - - def le(self, o, a, b): - return self.L.LE(a, b) - - def ge(self, o, a, b): - return self.L.GE(a, b) - - def lt(self, o, a, b): - return self.L.LT(a, b) - - def gt(self, o, a, b): - return self.L.GT(a, b) - - def and_condition(self, o, a, b): - return self.L.And(a, b) - - def or_condition(self, o, a, b): - return self.L.Or(a, b) - - def not_condition(self, o, a): - return self.L.Not(a) - - # === Formatting rules for cmath functions === - - def math_function(self, o, op): - # Fallback for unhandled MathFunction subclass: - # attempting to just call it. - return self.L.Call(o._name, op) - - def _cmath(self, o, *args): - k = o._ufl_handler_name_ - try: - name = math_table[self.scalar_type].get(k) - except Exception as e: - raise type(e)("Math function not found:", self.scalar_type, k) - if name is None: - raise RuntimeError("Not supported in current scalar mode") - return self.L.Call(name, args) - - # === Formatting rules for bessel functions === - # Some Bessel functions exist in gcc, as XSI extensions - # but not all. - - def bessel_j(self, o, n, v): - assert "complex" not in self.scalar_type - n = int(float(n)) - if n == 0: - return self.L.Call("j0", v) - elif n == 1: - return self.L.Call("j1", v) - else: - return self.L.Call("jn", (n, v)) - - def bessel_y(self, o, n, v): - assert "complex" not in self.scalar_type - n = int(float(n)) - if n == 0: - return self.L.Call("y0", v) - elif n == 1: - return self.L.Call("y1", v) - else: - return self.L.Call("yn", (n, v)) From b04591bd6ca60dbdf832ebbdbd420fdb25601dda Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Mon, 4 Sep 2023 14:28:19 +0100 Subject: [PATCH 12/13] Get rid of UFL2LNodes class --- ffcx/codegeneration/expression_generator.py | 7 +- ffcx/codegeneration/integral_generator.py | 7 +- ffcx/codegeneration/lnodes.py | 120 ++++++++++---------- 3 files changed, 63 insertions(+), 71 deletions(-) diff --git a/ffcx/codegeneration/expression_generator.py b/ffcx/codegeneration/expression_generator.py index 54af21215..c0a81a36e 100644 --- a/ffcx/codegeneration/expression_generator.py +++ b/ffcx/codegeneration/expression_generator.py @@ -27,7 +27,6 @@ def __init__(self, ir: ExpressionIR, backend: FFCXBackend): self.ir = ir self.backend = backend - self.ufl_to_language = L.UFL2LNodes() self.scope: Dict[Any, LNode] = {} self._ufl_names: Set[Any] = set() self.symbol_counters: DefaultDict[Any, int] = collections.defaultdict(int) @@ -296,7 +295,7 @@ def new_temp_symbol(self, basename): def get_var(self, v): if v._ufl_is_literal_: - return self.ufl_to_language.get(v) + return L.ufl_to_lnodes(v) f = self.scope.get(v) return f @@ -315,7 +314,7 @@ def generate_partition(self, symbol, F, mode): mt = attr.get('mt') if v._ufl_is_literal_: - vaccess = self.ufl_to_language.get(v) + vaccess = L.ufl_to_lnodes(v) elif mt is not None: # All finite element based terminals have table data, as well # as some, but not all, of the symbolic geometric terminals @@ -344,7 +343,7 @@ def generate_partition(self, symbol, F, mode): # Mapping UFL operator to target language self._ufl_names.add(v._ufl_handler_name_) - vexpr = self.ufl_to_language.get(v, *vops) + vexpr = L.ufl_to_lnodes(v, *vops) # Create a new intermediate for each subexpression # except boolean conditions and its childs diff --git a/ffcx/codegeneration/integral_generator.py b/ffcx/codegeneration/integral_generator.py index 452437f7e..2713f92ad 100644 --- a/ffcx/codegeneration/integral_generator.py +++ b/ffcx/codegeneration/integral_generator.py @@ -29,7 +29,6 @@ def __init__(self, ir, backend): # - definitions: for defining backend specific variables # - access: for accessing backend specific variables self.backend = backend - self.ufl_to_language = L.UFL2LNodes() # Set of operator names code has been generated for, used in the # end for selecting necessary includes @@ -75,7 +74,7 @@ def get_var(self, quadrature_rule, v): Returns the LNodes expression to access the value in the code. """ if v._ufl_is_literal_: - return self.ufl_to_language.get(v) + return L.ufl_to_lnodes(v) f = self.scopes[quadrature_rule].get(v) if f is None: f = self.scopes[None].get(v) @@ -294,7 +293,7 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): # cache if not self.get_var(quadrature_rule, v): if v._ufl_is_literal_: - vaccess = self.ufl_to_language.get(v) + vaccess = L.ufl_to_lnodes(v) elif mt is not None: # All finite element based terminals have table # data, as well as some, but not all, of the @@ -324,7 +323,7 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): # Mapping UFL operator to target language self._ufl_names.add(v._ufl_handler_name_) - vexpr = self.ufl_to_language.get(v, *vops) + vexpr = L.ufl_to_lnodes(v, *vops) # Create a new intermediate for each subexpression # except boolean conditions and its childs diff --git a/ffcx/codegeneration/lnodes.py b/ffcx/codegeneration/lnodes.py index 8aba3029e..4c86c7932 100644 --- a/ffcx/codegeneration/lnodes.py +++ b/ffcx/codegeneration/lnodes.py @@ -815,66 +815,60 @@ def as_statement(node): ) -class UFL2LNodes(object): - """UFL to LNodes translator class.""" - - def __init__(self): - self.force_floats = False - self.enable_strength_reduction = False - - # Lookup table for handler to call when the "get" method (below) is - # called, depending on the first argument type. - self.call_lookup = { - ufl.constantvalue.IntValue: lambda x: LiteralInt(int(x)), - ufl.constantvalue.FloatValue: lambda x: LiteralFloat(float(x)), - ufl.constantvalue.ComplexValue: lambda x: LiteralFloat(x.value()), - ufl.constantvalue.Zero: lambda x: LiteralFloat(0.0), - ufl.algebra.Product: lambda x, a, b: a * b, - ufl.algebra.Sum: lambda x, a, b: a + b, - ufl.algebra.Division: lambda x, a, b: a / b, - ufl.algebra.Abs: self.math_function, - ufl.algebra.Power: self.math_function, - ufl.algebra.Real: self.math_function, - ufl.algebra.Imag: self.math_function, - ufl.algebra.Conj: self.math_function, - ufl.classes.GT: lambda x, a, b: GT(a, b), - ufl.classes.GE: lambda x, a, b: GE(a, b), - ufl.classes.EQ: lambda x, a, b: EQ(a, b), - ufl.classes.NE: lambda x, a, b: NE(a, b), - ufl.classes.LT: lambda x, a, b: LT(a, b), - ufl.classes.LE: lambda x, a, b: LE(a, b), - ufl.classes.AndCondition: lambda x, a, b: And(a, b), - ufl.classes.OrCondition: lambda x, a, b: Or(a, b), - ufl.classes.NotCondition: lambda x, a: Not(a), - ufl.classes.Conditional: lambda x, c, t, f: Conditional(c, t, f), - ufl.classes.MinValue: self.math_function, - ufl.classes.MaxValue: self.math_function, - ufl.mathfunctions.Sqrt: self.math_function, - ufl.mathfunctions.Ln: self.math_function, - ufl.mathfunctions.Exp: self.math_function, - ufl.mathfunctions.Cos: self.math_function, - ufl.mathfunctions.Sin: self.math_function, - ufl.mathfunctions.Tan: self.math_function, - ufl.mathfunctions.Cosh: self.math_function, - ufl.mathfunctions.Sinh: self.math_function, - ufl.mathfunctions.Tanh: self.math_function, - ufl.mathfunctions.Acos: self.math_function, - ufl.mathfunctions.Asin: self.math_function, - ufl.mathfunctions.Atan: self.math_function, - ufl.mathfunctions.Erf: self.math_function, - ufl.mathfunctions.Atan2: self.math_function, - ufl.mathfunctions.MathFunction: self.math_function, - ufl.mathfunctions.BesselJ: self.math_function, - ufl.mathfunctions.BesselY: self.math_function, - } - - def get(self, o, *args): - # Call appropriate handler, depending on the type of o - otype = type(o) - if otype in self.call_lookup: - return self.call_lookup[otype](o, *args) - else: - raise RuntimeError(f"Missing lookup for expr type {otype}.") - - def math_function(self, o, *args): - return MathFunction(o._ufl_handler_name_, args) +def _math_function(op, *args): + return MathFunction(op._ufl_handler_name_, args) + + +# Lookup table for handler to call when the "get" method (below) is +# called, depending on the first argument type. +_ufl_call_lookup = { + ufl.constantvalue.IntValue: lambda x: LiteralInt(int(x)), + ufl.constantvalue.FloatValue: lambda x: LiteralFloat(float(x)), + ufl.constantvalue.ComplexValue: lambda x: LiteralFloat(x.value()), + ufl.constantvalue.Zero: lambda x: LiteralFloat(0.0), + ufl.algebra.Product: lambda x, a, b: a * b, + ufl.algebra.Sum: lambda x, a, b: a + b, + ufl.algebra.Division: lambda x, a, b: a / b, + ufl.algebra.Abs: _math_function, + ufl.algebra.Power: _math_function, + ufl.algebra.Real: _math_function, + ufl.algebra.Imag: _math_function, + ufl.algebra.Conj: _math_function, + ufl.classes.GT: lambda x, a, b: GT(a, b), + ufl.classes.GE: lambda x, a, b: GE(a, b), + ufl.classes.EQ: lambda x, a, b: EQ(a, b), + ufl.classes.NE: lambda x, a, b: NE(a, b), + ufl.classes.LT: lambda x, a, b: LT(a, b), + ufl.classes.LE: lambda x, a, b: LE(a, b), + ufl.classes.AndCondition: lambda x, a, b: And(a, b), + ufl.classes.OrCondition: lambda x, a, b: Or(a, b), + ufl.classes.NotCondition: lambda x, a: Not(a), + ufl.classes.Conditional: lambda x, c, t, f: Conditional(c, t, f), + ufl.classes.MinValue: _math_function, + ufl.classes.MaxValue: _math_function, + ufl.mathfunctions.Sqrt: _math_function, + ufl.mathfunctions.Ln: _math_function, + ufl.mathfunctions.Exp: _math_function, + ufl.mathfunctions.Cos: _math_function, + ufl.mathfunctions.Sin: _math_function, + ufl.mathfunctions.Tan: _math_function, + ufl.mathfunctions.Cosh: _math_function, + ufl.mathfunctions.Sinh: _math_function, + ufl.mathfunctions.Tanh: _math_function, + ufl.mathfunctions.Acos: _math_function, + ufl.mathfunctions.Asin: _math_function, + ufl.mathfunctions.Atan: _math_function, + ufl.mathfunctions.Erf: _math_function, + ufl.mathfunctions.Atan2: _math_function, + ufl.mathfunctions.MathFunction: _math_function, + ufl.mathfunctions.BesselJ: _math_function, + ufl.mathfunctions.BesselY: _math_function} + + +def ufl_to_lnodes(operator, *args): + # Call appropriate handler, depending on the type of operator + optype = type(operator) + if optype in _ufl_call_lookup: + return _ufl_call_lookup[optype](operator, *args) + else: + raise RuntimeError(f"Missing lookup for expr type {optype}.") From c76ee342f77cf15ff9c59b638d3a83ad552cd071 Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Mon, 4 Sep 2023 14:29:00 +0100 Subject: [PATCH 13/13] Fix documentation --- ffcx/codegeneration/lnodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ffcx/codegeneration/lnodes.py b/ffcx/codegeneration/lnodes.py index 4c86c7932..463b31892 100644 --- a/ffcx/codegeneration/lnodes.py +++ b/ffcx/codegeneration/lnodes.py @@ -819,7 +819,7 @@ def _math_function(op, *args): return MathFunction(op._ufl_handler_name_, args) -# Lookup table for handler to call when the "get" method (below) is +# Lookup table for handler to call when the ufl_to_lnodes method (below) is # called, depending on the first argument type. _ufl_call_lookup = { ufl.constantvalue.IntValue: lambda x: LiteralInt(int(x)),