Skip to content

Commit

Permalink
Frontend support for partial derivatives of functions (#10649)
Browse files Browse the repository at this point in the history
- Add frontend support for partial derivatives of functions
  (`df = der(f, x)`) for use by the new backend. Only enabled when
  using the `--newBackend` flag since the old backend would handle them
  incorrectly.
  • Loading branch information
perost committed May 4, 2023
1 parent af8b2da commit c254033
Show file tree
Hide file tree
Showing 18 changed files with 290 additions and 49 deletions.
5 changes: 5 additions & 0 deletions OMCompiler/Compiler/FrontEnd/DAE.mo
Expand Up @@ -470,6 +470,11 @@ public uniontype FunctionDefinition
ComponentRef inputParam "The input parameter the inverse is for";
Exp inverseCall "The inverse function call";
end FUNCTION_INVERSE;

record FUNCTION_PARTIAL_DERIVATIVE
Absyn.Path derivedFunction;
list<String> derivedVars;
end FUNCTION_PARTIAL_DERIVATIVE;
end FunctionDefinition;

public
Expand Down
78 changes: 39 additions & 39 deletions OMCompiler/Compiler/NFFrontEnd/NFBuiltinFuncs.mo

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFClass.mo
Expand Up @@ -213,6 +213,12 @@ constant Prefixes DEFAULT_PREFIXES = Prefixes.PREFIXES(
cls := match cls
case INSTANCED_CLASS()
then INSTANCED_CLASS(cls.ty, cls.elements, sections, cls.prefixes, cls.restriction);

case TYPED_DERIVED()
algorithm
InstNode.classApply(cls.baseClass, setSections, sections);
then
cls;
end match;
end setSections;

Expand Down
8 changes: 8 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFConvertDAE.mo
Expand Up @@ -1091,6 +1091,14 @@ algorithm
cls := InstNode.getClass(Function.instance(func));

dfunc := match cls
case Class.TYPED_DERIVED(restriction = Restriction.FUNCTION())
guard Function.isPartialDerivative(func)
algorithm
def := DAE.FunctionDefinition.FUNCTION_PARTIAL_DERIVATIVE(
Function.getDerivedFunctionName(func), Function.getDerivedInputNames(func));
then
Function.toDAE(func, def);

case Class.INSTANCED_CLASS(sections = sections, restriction = Restriction.FUNCTION())
algorithm
elems := convertFunctionParams(func.inputs, {});
Expand Down
4 changes: 4 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFEvalFunction.mo
Expand Up @@ -81,6 +81,10 @@ function evaluate
algorithm
if Function.isExternal(fn) then
result := evaluateExternal(fn, args, target);
elseif Function.isPartialDerivative(fn) then
// Partial derivatives of functions are differentiated by the backend, so
// make sure we don't try to evaluate the non-differentiated function body.
fail();
else
result := evaluateNormal(fn, args);
end if;
Expand Down
6 changes: 6 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFFlatten.mo
Expand Up @@ -2533,6 +2533,12 @@ algorithm
for fn_inv in fn.inverses loop
funcs := collectExpFuncs(fn_inv.inverseCall, funcs);
end for;

if Function.isPartialDerivative(fn) then
for f in Function.getCachedFuncs(Class.lastBaseClass(fn.node)) loop
flattenFunction(f, funcs);
end for;
end if;
end if;
end if;
end flattenFunction;
Expand Down
101 changes: 93 additions & 8 deletions OMCompiler/Compiler/NFFrontEnd/NFFunction.mo
Expand Up @@ -270,6 +270,7 @@ uniontype Function
Type returnType;
DAE.FunctionAttributes attributes;
list<FunctionDerivative> derivatives;
list<Integer> derivedInputs;
array<FunctionInverse> inverses;
Pointer<FunctionStatus> status;
Pointer<Integer> callCounter "Used during function evaluation to limit recursion.";
Expand All @@ -291,7 +292,7 @@ uniontype Function
// Make sure builtin functions aren't added to the function tree.
status := if isBuiltinAttr(attr) then FunctionStatus.COLLECTED else FunctionStatus.INITIAL;
fn := FUNCTION(path, node, inputs, outputs, locals, {}, Type.UNKNOWN(),
attr, {}, listArray({}), Pointer.create(status), Pointer.create(0));
attr, {}, {}, listArray({}), Pointer.create(status), Pointer.create(0));
end new;

function lookupFunctionSimple
Expand Down Expand Up @@ -467,6 +468,7 @@ uniontype Function
specialBuiltin := isSpecialBuiltin(fn);
fn.derivatives := FunctionDerivative.instDerivatives(fnNode, fn);
fn.inverses := FunctionInverse.instInverses(fnNode, fn);
fn.derivedInputs := instPartialDerivedVars(def.classDef, fn.inputs, fn, context, info);
fnNode := InstNode.cacheAddFunc(fnNode, fn, specialBuiltin);
then
(fnNode, specialBuiltin);
Expand Down Expand Up @@ -728,15 +730,21 @@ uniontype Function
algorithm
if isDefaultRecordConstructor(fn) then
s := IOStream.append(s, InstNode.toFlatString(fn.node));
elseif isPartialDerivative(fn) then
fn_name := if stringEmpty(overrideName) then Util.makeQuotedIdentifier(AbsynUtil.pathString(fn.path)) else overrideName;

s := IOStream.append(s, "function ");
s := IOStream.append(s, fn_name);
s := IOStream.append(s, " = der(");
s := IOStream.append(s, Util.makeQuotedIdentifier(AbsynUtil.pathString(getDerivedFunctionName(fn))));
s := IOStream.append(s, ", ");
s := IOStream.append(s, stringDelimitList(getDerivedInputNames(fn), ", "));
s := FlatModelicaUtil.appendComment(SCodeUtil.getElementComment(InstNode.definition(fn.node)), s);
s := IOStream.append(s, ")");
else
cmt := Util.getOptionOrDefault(SCodeUtil.getElementComment(InstNode.definition(fn.node)), SCode.COMMENT(NONE(), NONE()));
fn_name := if stringEmpty(overrideName) then Util.makeQuotedIdentifier(AbsynUtil.pathString(fn.path)) else overrideName;

fn_name := AbsynUtil.pathString(fn.path);
if stringEmpty(overrideName) then
fn_name := Util.makeQuotedIdentifier(fn_name);
else
fn_name := overrideName;
end if;
s := IOStream.append(s, "function ");
s := IOStream.append(s, fn_name);
s := FlatModelicaUtil.appendCommentString(SOME(cmt), s);
Expand Down Expand Up @@ -1394,7 +1402,7 @@ uniontype Function
if not isTyped(fn) then
// Type all the components in the function.
Typing.typeClassType(node, NFBinding.EMPTY_BINDING, context, node);
Typing.typeComponents(node, context);
Typing.typeComponents(node, context, preserveDerived = isPartialDerivative(fn));
if InstNode.isPartial(node) then
ClassTree.applyComponents(Class.classTree(InstNode.getClass(node)), boxFunctionParameter);
Expand All @@ -1403,6 +1411,7 @@ uniontype Function
// Make the slots and return type for the function.
fn.slots := makeSlots(fn.inputs);
checkParamTypes(fn);
checkPartialDerivativeTypes(fn);
fn.returnType := makeReturnType(fn);
end if;
end typeFunctionSignature;
Expand Down Expand Up @@ -1759,6 +1768,31 @@ uniontype Function
end if;
end isExternalObjectConstructorOrDestructor;
function isPartialDerivative
"Returns true if the function is a partial derivative of a function, df = der(f, x)."
input Function fn;
output Boolean res = not listEmpty(fn.derivedInputs);
end isPartialDerivative;
function getDerivedInputNames
"Returns the names of the differentiated inputs in a partial derivative,
df = der(f, x, y, z) => {\"x\", \"y\", \"z\"}"
input Function fn;
output list<String> names = {};
algorithm
for i in fn.derivedInputs loop
names := InstNode.name(listGet(fn.inputs, i)) :: names;
end for;

names := listReverseInPlace(names);
end getDerivedInputNames;

function getDerivedFunctionName
"Returns the name of the derived function in a partial derivative, df = der(f, x) => f"
input Function fn;
output Absyn.Path name = InstNode.fullPath(Class.lastBaseClass(fn.node), ignoreBaseClass = true);
end getDerivedFunctionName;

function inlineBuiltin
input Function fn;
output DAE.InlineType inlineType;
Expand Down Expand Up @@ -2379,6 +2413,24 @@ protected
end match;
end isValidParamState;

function checkPartialDerivativeTypes
input Function fn;
protected
InstNode node;
Type ty;
algorithm
for i in fn.derivedInputs loop
node := listGet(fn.inputs, i);
ty := InstNode.getType(node);

if not (Type.isReal(ty) and Type.isScalar(ty)) then
Error.addSourceMessage(Error.PARTIAL_DERIVATIVE_INPUT_INVALID_TYPE,
{InstNode.name(node), AbsynUtil.pathString(getDerivedFunctionName(fn))}, InstNode.info(fn.node));
fail();
end if;
end for;
end checkPartialDerivativeTypes;

public function makeReturnType
input Function fn;
output Type returnType;
Expand Down Expand Up @@ -2818,6 +2870,39 @@ protected
else ();
end match;
end checkUseBeforeAssignExp_traverse;

function instPartialDerivedVars
input SCode.ClassDef classDef;
input list<InstNode> inputs;
input Function fn;
input InstContext.Type context;
input SourceInfo info;
output list<Integer> derivedVars = {};
protected
Integer index;
algorithm
() := match classDef
case SCode.ClassDef.PDER()
algorithm
for var in classDef.derivedVariables loop
index := List.positionOnTrue(inputs, function InstNode.isNamed(name = var));

if index < 1 then
Error.addSourceMessage(Error.PARTIAL_DERIVATIVE_INPUT_NOT_FOUND,
{var, AbsynUtil.pathString(getDerivedFunctionName(fn))}, info);
fail();
end if;

derivedVars := index :: derivedVars;
end for;

derivedVars := listReverseInPlace(derivedVars);
then
();

else ();
end match;
end instPartialDerivedVars;
end Function;

annotation(__OpenModelica_Interface="frontend");
Expand Down
9 changes: 9 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFInst.mo
Expand Up @@ -547,6 +547,15 @@ algorithm
// get here through e.g. the interactive API. In that case just ignore them.
case SCode.OVERLOAD() then node;

// A partial derivative of a function, function df = der(f, x).
// Treat it as a short class definition here,
case SCode.PDER()
guard Flags.getConfigBool(Flags.NEW_BACKEND)
then expandClassDerived(def,
SCode.ClassDef.DERIVED(Absyn.TypeSpec.TPATH(cdef.functionPath, NONE()),
SCode.NOMOD(), SCode.defaultVarAttr),
node, info);

else
algorithm
Error.assertion(false, getInstanceName() + " got unknown class:\n" + SCodeDump.unparseElementStr(def), sourceInfo());
Expand Down
15 changes: 15 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFInstNode.mo
Expand Up @@ -526,6 +526,21 @@ uniontype InstNode
end match;
end name;

function isNamed
input InstNode node;
input String name;
output Boolean res;
algorithm
res := match node
case CLASS_NODE() then node.name == name;
case COMPONENT_NODE() then node.name == name;
case INNER_OUTER_NODE() then isNamed(node.innerNode, name);
case VAR_NODE() then node.name == name;
case NAME_NODE() then node.name == name;
else false;
end match;
end isNamed;

function className
input InstNode node;
output String name;
Expand Down
2 changes: 1 addition & 1 deletion OMCompiler/Compiler/NFFrontEnd/NFRecord.mo
Expand Up @@ -163,7 +163,7 @@ algorithm
attr := DAE.FUNCTION_ATTRIBUTES_DEFAULT;
status := Pointer.create(FunctionStatus.INITIAL);
InstNode.cacheAddFunc(node, Function.FUNCTION(path, ctor_node, inputs,
{out_rec}, locals, {}, Type.UNKNOWN(), attr, {}, listArray({}), status, Pointer.create(0)), false);
{out_rec}, locals, {}, Type.UNKNOWN(), attr, {}, {}, listArray({}), status, Pointer.create(0)), false);
end instDefaultConstructor;

function checkLocalFieldOrder
Expand Down
4 changes: 3 additions & 1 deletion OMCompiler/Compiler/NFFrontEnd/NFTyping.mo
Expand Up @@ -130,6 +130,7 @@ end typeClass;
function typeComponents
input InstNode cls;
input InstContext.Type context;
input Boolean preserveDerived = false;
protected
Class c = InstNode.getClass(cls), c2;
ClassTree cls_tree;
Expand Down Expand Up @@ -158,7 +159,8 @@ algorithm

// For derived types with dimensions we keep them as they are, because we
// need to preserve the dimensions.
case Class.TYPED_DERIVED(ty = Type.ARRAY())
case Class.TYPED_DERIVED()
guard preserveDerived or Type.isArray(c.ty)
algorithm
typeComponents(c.baseClass, context);
then
Expand Down
5 changes: 5 additions & 0 deletions OMCompiler/Compiler/Template/DAEDumpTV.mo
Expand Up @@ -998,6 +998,11 @@ package DAE
ComponentRef inputParam "The input parameter the inverse is for";
Exp inverseCall "The inverse function call";
end FUNCTION_INVERSE;

record FUNCTION_PARTIAL_DERIVATIVE
Absyn.Path derivedFunction;
list<String> derivedVars;
end FUNCTION_PARTIAL_DERIVATIVE;
end FunctionDefinition;

uniontype derivativeCond "Different conditions on derivatives"
Expand Down
13 changes: 13 additions & 0 deletions OMCompiler/Compiler/Template/DAEDumpTpl.tpl
Expand Up @@ -67,6 +67,14 @@ end dumpFunctions;
template dumpFunction(DAE.Function function)
::=
match function
case FUNCTION(functions = {FUNCTION_PARTIAL_DERIVATIVE(__)}) then
let fn_name = AbsynDumpTpl.dumpPathNoQual(path)
let cmt_str = dumpCommentOpt(comment)
let ann_str = dumpClassAnnotation(comment)
let impure_str = if isImpure then 'impure '
<<
<%impure_str%>function <%fn_name%> = <%dumpFunctionDefinitions(functions)%><%cmt_str%><%if ann_str then " "%><%ann_str%>;
>>
case FUNCTION(__) then
let cmt_str = dumpCommentOpt(comment)
let ann_str = dumpClassAnnotation(comment)
Expand Down Expand Up @@ -111,6 +119,11 @@ match functions
>>
case FUNCTION_DER_MAPPER(__) then ''
case FUNCTION_INVERSE(__) then ''
case FUNCTION_PARTIAL_DERIVATIVE(__) then
let vars = (derivedVars |> var => var ;separator=", ")
<<
der(<%AbsynDumpTpl.dumpPathNoQual(derivedFunction)%>, <%vars%>)
>>
end dumpFunctionDefinition;

template dumpExternalDecl(ExternalDecl externalDecl)
Expand Down
4 changes: 4 additions & 0 deletions OMCompiler/Compiler/Util/Error.mo
Expand Up @@ -881,6 +881,10 @@ public constant ErrorTypes.Message NF_PDE_NOT_IMPLEMENTED = ErrorTypes.MESSAGE(4
Gettext.gettext("PDEModelica is not yet supported by the new front-end, using the old front-end instead."));
public constant ErrorTypes.Message NON_CONSTANT_IN_ENCLOSING_SCOPE = ErrorTypes.MESSAGE(403, ErrorTypes.TRANSLATION(), ErrorTypes.ERROR(),
Gettext.gettext("Component ‘%s‘ was found in an enclosing scope but is not a constant."));
public constant ErrorTypes.Message PARTIAL_DERIVATIVE_INPUT_NOT_FOUND = ErrorTypes.MESSAGE(404, ErrorTypes.TRANSLATION(), ErrorTypes.ERROR(),
Gettext.gettext("‘%s‘ in partial derivative of ‘%s‘ does not name an input parameter of the function."));
public constant ErrorTypes.Message PARTIAL_DERIVATIVE_INPUT_INVALID_TYPE = ErrorTypes.MESSAGE(405, ErrorTypes.TRANSLATION(), ErrorTypes.ERROR(),
Gettext.gettext("‘%s‘ in partial derivative of ‘%s‘ is not a scalar Real input parameter of the function."));

public constant ErrorTypes.Message INITIALIZATION_NOT_FULLY_SPECIFIED = ErrorTypes.MESSAGE(496, ErrorTypes.TRANSLATION(), ErrorTypes.WARNING(),
Gettext.gettext("The initial conditions are not fully specified. %s."));
Expand Down
@@ -0,0 +1,24 @@
// name: FunctionPartialDerivative1
// keywords:
// status: correct
// cflags: -d=newInst, --newBackend
//

model FunctionPartialDerivative1
function f
input Real x;
output Real y = x^2;
end f;

function df = der(f, x);

Real y = df(0);
end FunctionPartialDerivative1;

// Result:
// function FunctionPartialDerivative1.df = der(FunctionPartialDerivative1.f, x);
//
// class FunctionPartialDerivative1
// Real y = FunctionPartialDerivative1.df(0.0);
// end FunctionPartialDerivative1;
// endResult
@@ -0,0 +1,26 @@
// name: FunctionPartialDerivative2
// keywords:
// status: incorrect
// cflags: -d=newInst, --newBackend
//

model FunctionPartialDerivative2
function f
input Real x;
output Real y = x^2;
end f;

function df = der(f, y);

Real y = df(0);
end FunctionPartialDerivative2;

// Result:
// Error processing file: FunctionPartialDerivative2.mo
// [flattening/modelica/scodeinst/FunctionPartialDerivative2.mo:15:3-15:17:writable] Error: ‘y‘ in partial derivative of ‘FunctionPartialDerivative2.f‘ does not name an input parameter of the function.
//
// # Error encountered! Exiting...
// # Please check the error message and the flags.
//
// Execution failed!
// endResult

0 comments on commit c254033

Please sign in to comment.