Skip to content

Commit

Permalink
- added an interface function for partial derivatives of functions
Browse files Browse the repository at this point in the history
git-svn-id: https://openmodelica.org/svn/OpenModelica/trunk@18683 f25d12d1-65f4-0310-ae8a-bbce733d8d8e
  • Loading branch information
Willi Braun committed Jan 17, 2014
1 parent f17a938 commit f46533d
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 19 deletions.
3 changes: 2 additions & 1 deletion Compiler/BackEnd/BackendDAE.mo
Expand Up @@ -611,7 +611,8 @@ uniontype DiffentiationType "Define the behavoir of differentation method for (e
record SIMPLE_DIFFERENTAION "Used to solve expression for a cref or by the older jacobian generation, differation w.r.t. a given cref"
end SIMPLE_DIFFERENTAION;

record DIFFERENTAION_FUNCTION "Used to solve expression for a cref or by the older jacobian generation, differation w.r.t. a given cref"
record DIFFERENTAION_FUNCTION "Used to differentiate a function call w.r.t. a given cref, which need to expand the input arguments
by differentiate arguments."
end DIFFERENTAION_FUNCTION;

record FULL_JACOBIAN "Used to generate a full jacobian matrix"
Expand Down
168 changes: 150 additions & 18 deletions Compiler/BackEnd/Differentiate.mo
Expand Up @@ -255,7 +255,6 @@ algorithm
end matchcontinue;
end differentiateExpCrefFunction;


public function differentiateExpCrefFullJacobian "function: differentiateEquationTime
Differentiates an equation with respect to the time variable."
input DAE.Exp inExp;
Expand Down Expand Up @@ -295,6 +294,40 @@ algorithm

end matchcontinue;
end differentiateExpCrefFullJacobian;

public function differentiateFunctionPartial
"function: differentiateFunctionPartial
Differentiates an function with respect to a list of ComponentReference
with are inputs arguments of that function."
input DAE.Function inFunction;
input list<DAE.ComponentRef> inCrefs;
input Absyn.Path inDerFunctionName;
input DAE.FunctionTree inFunctionTree;
output DAE.Function outFunction;
output DAE.FunctionTree outFunctionTree;
algorithm
(outFunction, outFunctionTree) := matchcontinue(inFunction, inCrefs, inDerFunctionName, inFunctionTree)
local
String msg;
DAE.Function dfunction;
BackendDAE.DifferentiateInputData diffData;
Absyn.Path fname;
case (_, _, _, _)
equation
diffData = BackendDAE.DIFFINPUTDATA(NONE(), NONE(), NONE(), NONE(), SOME({}), NONE(), NONE());
(dfunction, outFunctionTree) = differentiatePartialFunction(inFunction, inCrefs, inDerFunctionName, diffData, BackendDAE.DIFFERENTAION_FUNCTION(), inFunctionTree);
then (dfunction, outFunctionTree);
else
equation
true = Flags.isSet(Flags.FAILTRACE);
fname = DAEUtil.functionName(inFunction);
msg = "\nDifferentiate.differentiateFunctionPartial failed for function: " +& Absyn.pathString(fname) +& "\n";
Debug.fprint(Flags.FAILTRACE, msg);
then fail();

end matchcontinue;
end differentiateFunctionPartial;

// =============================================================================
// further interface functions to differentiation
// - differentiateEquation
Expand Down Expand Up @@ -1393,7 +1426,7 @@ algorithm

case (e as DAE.CALL(path = path,expLst = expl), _, _, _, _)
equation
(e1, funcs) = differentiateFunction(e, inDiffwrtCref, inInputData, inDiffType, inFunctionTree);
(e1, funcs) = differentiateFunctionCall(e, inDiffwrtCref, inInputData, inDiffType, inFunctionTree);
(e,_,_,_) = Inline.inlineExp(e1,(SOME(funcs),{DAE.NORM_INLINE(),DAE.NO_INLINE()}),DAE.emptyElementSource/*TODO:Can we propagate source?*/);
then
(e, funcs);
Expand Down Expand Up @@ -1897,7 +1930,7 @@ end differentiateBinary;
// functions to generate derivative of a function
// =============================================================================

protected function differentiateFunction"
protected function differentiateFunctionCall"
Author: Frenkel TUD, wbraun

"
Expand Down Expand Up @@ -1927,27 +1960,20 @@ algorithm
DAE.Type tp, dtp;
list<Boolean> blst;
list<DAE.Type> tlst;
list<tuple<DAE.Exp,Boolean>> expBoolLst;
String typstring, dastring, funstring, str;
list<String> typlststring;
String typstring,dastring, funstring, str;
DAE.TailCall tc;

list<DAE.Element> funcbody, funcbodyalgorithm, funcbodyDer, inputVars, inputVarsNoDer, inputVarsDer, outputVars, outputVarsNoDer, outputVarsDer, protectedVars, protectedVarsNoDer, protectedVarsDer, newProtectedVars, newVarsInOut;
list<DAE.Element> funcbody, funcbodyDer;
list<DAE.Element> inputVars, inputVarsNoDer, inputVarsDer;
list<DAE.Element> outputVars, outputVarsNoDer, outputVarsDer;
list<DAE.Element> protectedVars, protectedVarsNoDer, protectedVarsDer, newProtectedVars;
list<DAE.Statement> bodyStmts, derbodyStmts;

DAE.FunctionDefinition derfuncdef;

DAE.Function func,dfunc;

list<DAE.Function> fns;

list<DAE.Var> typeVars;
list<list<DAE.Var>> typeVarsLst;

list<DAE.Type> typeInOutVars;
list<tuple<DAE.Exp,Boolean>> expBoolLst;

BackendDAE.Variables indepVars;

String funcname;

case (DAE.CALL(path=path,expLst=expl,attr=DAE.CALL_ATTR(tuple_=b,builtin=c,isImpure=isImpure,ty=ty,tailCall=tc)), _, _, _, _)
Expand Down Expand Up @@ -2099,11 +2125,117 @@ algorithm

else
equation
str = "Differentiate.differentiateFunction failed for " +& ExpressionDump.printExpStr(inExp);
str = "Differentiate.differentiateFunctionCall failed for " +& ExpressionDump.printExpStr(inExp);
Debug.fprint(Flags.FAILTRACE, str);
then fail();
end matchcontinue;
end differentiateFunctionCall;


protected function differentiatePartialFunction"
Author: wbraun
"
input DAE.Function inFunction;
input list<DAE.ComponentRef> inDiffwrtCrefs;
input Absyn.Path inDerFunctionName;
input BackendDAE.DifferentiateInputData inInputData;
input BackendDAE.DiffentiationType inDiffType;
input DAE.FunctionTree inFunctionTree;
output DAE.Function outDerFunction;
output DAE.FunctionTree outFunctionTree;
algorithm
(outDerFunction, outFunctionTree) :=
matchcontinue(inFunction, inDiffwrtCrefs, inDerFunctionName, inInputData, inDiffType, inFunctionTree)
local

BackendDAE.DifferentiateInputData inputData, diffFuncData;

Absyn.Path path,dpath;
Boolean isImpure;
DAE.InlineType dinl;
DAE.Type ty;
DAE.FunctionTree functions;
DAE.Type tp;
String str;

list<DAE.Element> funcbody, funcbodyDer;
list<DAE.Element> inputVars, inputVarsNoDer, inputVarsDer;
list<DAE.Element> outputVars, outputVarsNoDer, outputVarsDer;
list<DAE.Element> protectedVars, protectedVarsNoDer, protectedVarsDer, newProtectedVars;
list<DAE.Statement> bodyStmts, derbodyStmts;

DAE.Function func,dfunc;

String funcname;
DAE.ComponentRef diffwrtCref;
list<DAE.ComponentRef> diffwrtCrefs;

case (_, {}, _, _, _, _) then (inFunction, inFunctionTree);

// differentiate function
case (func, diffwrtCref::diffwrtCrefs, dpath, _, _, _)
equation
// debug
//funstring = Tpl.tplString(DAEDumpTpl.dumpFunction, func);
//print("Function: \n" +& funstring +& "\n");

inputVars = DAEUtil.getFunctionInputVars(func);
outputVars = DAEUtil.getFunctionOutputVars(func);
protectedVars = DAEUtil.getFunctionProtectedVars(func);
bodyStmts = DAEUtil.getFunctionAlgorithmStmts(func);

path = DAEUtil.functionName(func);
funcname = Util.modelicaStringToCStr(Absyn.pathString(path), false);
diffFuncData = BackendDAE.DIFFINPUTDATA(NONE(),NONE(),NONE(),NONE(),NONE(),NONE(),SOME(funcname));

(inputVarsDer, functions, inputVarsNoDer, _) = differentiateElementVars(inputVars, diffwrtCref, diffFuncData, BackendDAE.DIFFERENTAION_FUNCTION(), inFunctionTree, {}, {}, {});
(outputVarsDer, functions, outputVarsNoDer, _) = differentiateElementVars(outputVars, diffwrtCref, diffFuncData, BackendDAE.DIFFERENTAION_FUNCTION(), functions, {}, {}, {});
(protectedVarsDer, functions, protectedVarsNoDer, _) = differentiateElementVars(protectedVars, diffwrtCref, diffFuncData, BackendDAE.DIFFERENTAION_FUNCTION(), functions, {}, {}, {});
//print("Finished diffed Vars\n");

//add protected variables to dependent Vars
(inputData,_) = addElementVars2InDep(inputVarsNoDer, functions, diffFuncData);
(inputData,_) = addElementVars2InDep(outputVarsNoDer, functions, inputData);
(inputData,_) = addElementVars2InDep(protectedVarsNoDer, functions, inputData);

// differentiate algorithm statemeants
(derbodyStmts, functions) = differentiateStatements(listReverse(bodyStmts), diffwrtCref, inputData, BackendDAE.DIFFERENTAION_FUNCTION(), {}, functions);

tp = DAEUtil.getFunctionType(func);

//append differentiatet inputsVars to protected vars, since
//for partial differentiation the input arguments are not expanded.
inputVars = listAppend(protectedVars, inputVarsDer);
protectedVars = listAppend(protectedVars, protectedVarsDer);
funcbodyDer = listAppend(inputVars, outputVarsDer);
funcbodyDer = listAppend(funcbodyDer, protectedVars);

//change output vars to protected vars and direction bidir
newProtectedVars = List.map1(outputVars, DAEUtil.setElementVarVisibility, DAE.PROTECTED());
newProtectedVars = List.map1(newProtectedVars, DAEUtil.setElementVarDirection, DAE.BIDIR());
funcbodyDer = listAppend(funcbodyDer, newProtectedVars);

funcbodyDer = listAppend(funcbodyDer, {DAE.ALGORITHM(DAE.ALGORITHM_STMTS(derbodyStmts), DAE.emptyElementSource)});

isImpure = DAEUtil.getFunctionImpureAttribute(func);
dinl = DAEUtil.getFunctionInlineType(func);
dfunc = DAE.FUNCTION(dpath, {DAE.FUNCTION_DEF(funcbodyDer)}, tp, false, isImpure, dinl, DAE.emptyElementSource, NONE());

// debug
//funstring = Tpl.tplString(DAEDumpTpl.dumpFunction, func);
//print("Function: \n" +& funstring +& "\n");
(dfunc, functions) = differentiatePartialFunction(dfunc, diffwrtCrefs, dpath, inInputData, inDiffType, functions);
then
(dfunc, functions);

else
equation
path = DAEUtil.functionName(inFunction);
str = "\nDifferentiate.differentiatePartialFunction failed for function: " +& Absyn.pathString(path) +& "\n";
Debug.fprint(Flags.FAILTRACE, str);
then fail();
end matchcontinue;
end differentiateFunction;
end differentiatePartialFunction;

protected function differentiateElementVars
input list<DAE.Element> inElements; // in as DAE.VAR(_)
Expand Down
20 changes: 20 additions & 0 deletions Compiler/FrontEnd/DAEUtil.mo
Expand Up @@ -2611,6 +2611,26 @@ algorithm
end match;
end getFunctionType;

public function getFunctionImpureAttribute
input DAE.Function fn;
output Boolean outImpure;
algorithm
outImpure := match fn
local
case DAE.FUNCTION(isImpure = outImpure) then outImpure;
end match;
end getFunctionImpureAttribute;

public function getFunctionInlineType
input DAE.Function fn;
output DAE.InlineType outInlineType;
algorithm
outImpure := match fn
local
case DAE.FUNCTION(inlineType = outInlineType) then outInlineType;
end match;
end getFunctionInlineType;

public function getFunctionInputVars
input DAE.Function fn;
output list<DAE.Element> outEls;
Expand Down

0 comments on commit f46533d

Please sign in to comment.