Skip to content

Commit

Permalink
partial differentiation of function with record outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Willi authored and OpenModelica-Hudson committed Nov 30, 2018
1 parent 210af85 commit 0a46029
Showing 1 changed file with 43 additions and 9 deletions.
52 changes: 43 additions & 9 deletions Compiler/BackEnd/Differentiate.mo
Expand Up @@ -555,7 +555,7 @@ algorithm
*/
(outDiffedExp, outFunctionTree) := match inExp
local
Absyn.Path p;
Absyn.Path p, p1, p2;
Boolean b;
DAE.CallAttributes attr;
DAE.Exp e1, e2, e3, actual, simplified;
Expand All @@ -566,6 +566,7 @@ algorithm
Integer i;
String s1, s2, stp;
list<String> strLst;
list<DAE.Var> varLst;
//String se1;
list<DAE.Exp> sub, expl;
list<list<DAE.Exp>> matrix, dmatrix;
Expand Down Expand Up @@ -684,15 +685,25 @@ algorithm
then (res, functionTree);


// differentiate tsub
// differentiate rsub
case e1 as DAE.RSUB()
algorithm
(res1, functionTree) := differentiateExp(e1.exp, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
if not referenceEq(e1.exp, res1) then
e1.exp := res1;
(e1,_) := ExpressionSimplify.simplify1(e1);
try
(expl, strLst) := match res1
case DAE.RECORD(exps=expl,comp=strLst) then (expl, strLst);
case DAE.CALL(path=p1,expLst=expl,attr=DAE.CALL_ATTR(ty=DAE.T_COMPLEX(complexClassType=ClassInf.RECORD(path=p2), varLst=varLst)))
guard Absyn.pathEqual(p1,p2)
then (expl, list(v.name for v in varLst));
end match;
res := listGet(expl, List.position1OnTrue(strLst, stringEq, e1.fieldName));
else
e1.exp := res1;
(res,_) := ExpressionSimplify.simplify1(e1);
end try;
end if;
then (e1, functionTree);
then (res, functionTree);

// differentiate tuple
case DAE.TUPLE(PR=expl) equation
Expand Down Expand Up @@ -2393,15 +2404,21 @@ protected function createPartialArguments
algorithm
outExp := matchcontinue(outputType, inCall)
local
Absyn.Path path;
Absyn.Path path, rPath;
DAE.CallAttributes attr;
list<DAE.Exp> expLst;
DAE.Exp ezero, e;
DAE.Dimensions dims;
list<DAE.Type> tys;
list<DAE.Var> varLst;
list<String> varNames;

case (DAE.T_COMPLEX(complexClassType=ClassInf.RECORD()), DAE.CALL(path=path, attr=attr))
then DAE.CALL(path, listAppend(inOrginalExpl,inArgs), attr);
case (DAE.T_COMPLEX(complexClassType=ClassInf.RECORD(path=rPath),varLst=varLst), DAE.CALL(path=path, attr=attr))
equation
tys = list(DAEUtil.varType(v) for v in varLst);
varNames = list(DAEUtil.typeVarIdent(v) for v in varLst);
expLst = createPartialArgumentsRecord(tys, varNames, inArgs, inDiffedArgs, inOrginalExpl, inCall);
then DAE.RECORD(rPath, expLst, varNames, outputType);

case (DAE.T_COMPLEX(complexClassType=ClassInf.RECORD()), DAE.TSUB(exp=DAE.CALL(path=path, attr=attr)))
then DAE.CALL(path, listAppend(inOrginalExpl,inArgs), attr);
Expand Down Expand Up @@ -2438,6 +2455,19 @@ algorithm
threaded for tp in inTypesLst, number in 1:listLength(inTypesLst));
end createPartialArgumentsTuple;

protected function createPartialArgumentsRecord
input list<DAE.Type> inTypesLst;
input list<DAE.String> inVarNames;
input list<DAE.Exp> inArgs;
input list<DAE.Exp> inDiffedArgs;
input list<DAE.Exp> inOrginalExpl;
input DAE.Exp inCall;
output list<DAE.Exp> outExpLst;
algorithm
outExpLst := list( createPartialArguments(tp, inArgs, inDiffedArgs, inOrginalExpl, (DAE.RSUB(inCall, -1, name, tp)) )
threaded for tp in inTypesLst, name in inVarNames);
end createPartialArgumentsRecord;

protected function createPartialDifferentiatedExp
"Generates an expression with a sum partial derivatives."
input list<DAE.Exp> inDiffExpl;
Expand Down Expand Up @@ -2518,8 +2548,12 @@ algorithm
DAE.CallAttributes attr;
DAE.Type ty;
Integer ix;
String name;

case DAE.RSUB(exp=DAE.CALL(path=path, attr=attr), ix=ix, fieldName=name, ty=ty)
then DAE.RSUB(DAE.CALL(path, expLst, attr), ix, name, ty);

case DAE.TSUB(exp=DAE.CALL(path=path, attr=attr), ix =ix, ty=ty)
case DAE.TSUB(exp=DAE.CALL(path=path, attr=attr), ix=ix, ty=ty)
then DAE.TSUB(DAE.CALL(path, expLst, attr), ix, ty);

case DAE.CALL(path=path, attr=attr) equation
Expand Down

0 comments on commit 0a46029

Please sign in to comment.