Skip to content

Commit

Permalink
[NF] Fix function derivative mapping.
Browse files Browse the repository at this point in the history
- Add lower order derivatives to the function derivative mappings like
  the old frontend does.
  • Loading branch information
perost committed Jul 17, 2019
1 parent 9587ed9 commit f62be2d
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 51 deletions.
69 changes: 21 additions & 48 deletions OMCompiler/Compiler/FrontEnd/InstUtil.mo
Expand Up @@ -4294,62 +4294,35 @@ public function addNameToDerivativeMapping
input Absyn.Path path;
output list<DAE.Function> outElts;
algorithm
outElts := match(inElts,path)
local
DAE.Function elt;
list<DAE.FunctionDefinition> funcs;
DAE.Type tp;
Absyn.Path p;
Boolean part,isImpure;
DAE.InlineType inlineType;
DAE.ElementSource source;
Option<SCode.Comment> cmt;
list<DAE.Function> elts;
SCode.Visibility visiblity;

case({},_) then {};

case(DAE.FUNCTION(p,funcs,tp,visiblity,part,isImpure,inlineType,source,cmt)::elts,_)
equation
elts = addNameToDerivativeMapping(elts,path);
funcs = addNameToDerivativeMappingFunctionDefs(funcs,path);
then DAE.FUNCTION(p,funcs,tp,visiblity,part,isImpure,inlineType,source,cmt)::elts;
outElts := list(
match fn
case DAE.FUNCTION()
algorithm
fn.functions := addNameToDerivativeMappingFunctionDefs(fn.functions, path);
then
fn;

case(elt::elts,_)
equation
elts = addNameToDerivativeMapping(elts,path);
then elt::elts;
end match;
else fn;
end match
for fn in inElts);
end addNameToDerivativeMapping;

protected function addNameToDerivativeMappingFunctionDefs " help function to addNameToDerivativeMappingElts"
protected function addNameToDerivativeMappingFunctionDefs " help function to addNameToDerivativeMapping"
input list<DAE.FunctionDefinition> inFuncs;
input Absyn.Path path;
output list<DAE.FunctionDefinition> outFuncs;
algorithm
outFuncs := match(inFuncs,path)
local
DAE.FunctionDefinition func;
Absyn.Path p1,p2;
Integer do;
Option<Absyn.Path> dd;
list<Absyn.Path> lowerOrderDerivatives;
list<tuple<Integer,DAE.derivativeCond>> conds;
list<DAE.FunctionDefinition> funcs;

case({},_) then {};

case(DAE.FUNCTION_DER_MAPPER(p1,p2,do,conds,dd,lowerOrderDerivatives)::funcs,_)
equation
funcs = addNameToDerivativeMappingFunctionDefs(funcs,path);
then DAE.FUNCTION_DER_MAPPER(p1,p2,do,conds,dd,path::lowerOrderDerivatives)::funcs;

case(func::funcs,_)
equation
funcs = addNameToDerivativeMappingFunctionDefs(funcs,path);
then func::funcs;
outFuncs := list(
match fn
case DAE.FUNCTION_DER_MAPPER()
algorithm
fn.lowerOrderDerivatives := path :: fn.lowerOrderDerivatives;
then
fn;

end match;
else fn;
end match
for fn in inFuncs);
end addNameToDerivativeMappingFunctionDefs;

public function getDeriveAnnotation "
Expand Down
27 changes: 27 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFFunction.mo
Expand Up @@ -470,6 +470,33 @@ uniontype Function
end match;
end getCachedFuncs;

function mapCachedFuncs
input InstNode inNode;
input MapFn mapFn;

partial function MapFn
input output Function fn;
end MapFn;
protected
InstNode cls_node;
CachedData cache;
algorithm
cls_node := InstNode.classScope(inNode);
cache := InstNode.getFuncCache(cls_node);

cache := match cache
case CachedData.FUNCTION()
algorithm
cache.funcs := list(mapFn(fn) for fn in cache.funcs);
then
cache;

else fail();
end match;

InstNode.setFuncCache(cls_node, cache);
end mapCachedFuncs;

function isEvaluated
input Function fn;
output Boolean evaluated;
Expand Down
28 changes: 26 additions & 2 deletions OMCompiler/Compiler/NFFrontEnd/NFFunctionDerivative.mo
Expand Up @@ -60,6 +60,7 @@ public
InstNode derivedFn;
Expression order;
list<tuple<Integer, Condition>> conditions;
list<InstNode> lowerOrderDerivatives;
end FUNCTION_DER;

function instDerivatives
Expand Down Expand Up @@ -124,7 +125,7 @@ public
list(conditionToDAE(c) for c in fnDer.conditions),
// TODO: Figure out if the two fields below are needed.
NONE(),
{}
list(Function.name(listHead(Function.getCachedFuncs(fn))) for fn in fnDer.lowerOrderDerivatives)
);
end toDAE;

Expand Down Expand Up @@ -186,9 +187,10 @@ protected
case SCode.Mod.MOD(subModLst = attrs, binding = SOME(Absyn.CREF(acref)))
algorithm
(_, der_node) := Function.instFunction(acref, scope, mod.info);
addLowerOrderDerivative(der_node, fnNode);
(order, conds) := getDerivativeAttributes(attrs, fn, fnNode, mod.info);
then
FUNCTION_DER(der_node, fnNode, order, conds) :: fnDers;
FUNCTION_DER(der_node, fnNode, order, conds, {}) :: fnDers;

// Give a warning if the derivative annotation doesn't specify a function name.
case SCode.Mod.MOD()
Expand Down Expand Up @@ -286,6 +288,28 @@ protected
fail();
end getInputIndex;

function addLowerOrderDerivative
input InstNode fnNode;
input InstNode lowerDerNode;
algorithm
Function.mapCachedFuncs(fnNode, function addLowerOrderDerivative2(lowerDerNode = lowerDerNode));
end addLowerOrderDerivative;

function addLowerOrderDerivative2
input output Function fn;
input InstNode lowerDerNode;
algorithm
fn.derivatives := list(
match fn_der
case FUNCTION_DER()
algorithm
fn_der.lowerOrderDerivatives := lowerDerNode :: fn_der.lowerOrderDerivatives;
then
fn_der;
end match
for fn_der in fn.derivatives);
end addLowerOrderDerivative2;

annotation(__OpenModelica_Interface="frontend");
end NFFunctionDerivative;

2 changes: 1 addition & 1 deletion OMCompiler/Compiler/Script/CevalScriptBackend.mo
Expand Up @@ -3210,7 +3210,7 @@ algorithm
if Flags.isSet(Flags.GC_PROF) then
print(GC.profStatsStr(GC.getProfStats(), head="GC stats after front-end:") + "\n");
end if;
ExecStat.execStat("FrontEnd - DAE generated");
ExecStat.execStat("FrontEnd - DAE generated");
odae := SOME(dae);
else
// Return odae=NONE(); needed to update cache and symbol table if we fail
Expand Down

0 comments on commit f62be2d

Please sign in to comment.