Skip to content

Commit

Permalink
[Backend] fix differentiation of functions with function arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Willi Braun authored and OpenModelica-Hudson committed Feb 12, 2018
1 parent 3cc00f3 commit 5056d0b
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions Compiler/BackEnd/Differentiate.mo
Expand Up @@ -1365,13 +1365,21 @@ algorithm
cr = ComponentReference.crefPrefixDer(cr);
cr = createDifferentiatedCrefName(cr, inDiffwrtCref, matrixName);
res = Expression.makeCrefExp(cr, tp);
then
(res, inFunctionTree);

b = ComponentReference.crefEqual(DAE.CREF_IDENT("$",DAE.T_REAL_DEFAULT,{}), inDiffwrtCref);
(zero,_) = Expression.makeZeroExpression(Expression.arrayDimension(tp));
/* Differentiate with respect to DAE.CREF_IDENT(ident="$") demands zero expressions */
case (DAE.CALL(path=Absyn.IDENT(name = "der"),expLst = {e}), DAE.CREF_IDENT(ident="$"), _, _, _)
equation
(zero,_) = Expression.makeZeroExpression(Expression.arrayDimension(Expression.typeof(e)));
then
(zero, inFunctionTree);

res = if b then zero else res;
case (e as DAE.CALL(attr=DAE.CALL_ATTR(ty=tp)), DAE.CREF_IDENT(ident="$"), _, _, _)
equation
(zero,_) = Expression.makeZeroExpression(Expression.arrayDimension(tp));
then
(res, inFunctionTree);
(zero, inFunctionTree);

// differentiate builtin calls with 1 argument
case (DAE.CALL(path=Absyn.IDENT(name),attr=DAE.CALL_ATTR(builtin=true),expLst={e}), _, _, _, _)
Expand All @@ -1383,7 +1391,6 @@ algorithm
//print("\nresults to exp: " + s1);
then (res, funcs);


// differentiate builtin calls with N arguments with match
// der(arctan2(y,0)) = der(sign(y)*pi/2) = 0
case (DAE.CALL(path=Absyn.IDENT("atan2"),attr=DAE.CALL_ATTR(builtin=true),expLst={_,e1 as DAE.RCONST(real=0.0)}), _, _, _, _)
Expand Down Expand Up @@ -2301,20 +2308,26 @@ algorithm
expBoolLst = List.threadTuple(expl, blst);
expBoolLst = List.filterOnTrue(expBoolLst, Util.tuple22);
expl1 = List.map(expBoolLst, Util.tuple21);
(dexpl, functions) = List.map3Fold(expl1, function differentiateExp(maxIter=maxIter), inDiffwrtCref, inInputData, inDiffType, functions);
(dexplZero, functions) = List.map3Fold(expl1, function differentiateExp(maxIter=maxIter), DAE.CREF_IDENT("$",DAE.T_REAL_DEFAULT,{}), BackendDAE.emptyInputData, BackendDAE.GENERIC_GRADIENT(), functions);
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION_VERBOSE) then
print("### Selected Arguments: \n");
print(stringDelimitList(List.map(expl1, ExpressionDump.printExpStr), ", ") + "\n");
end if;

if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
print("### differentiated argument list:\n");
(dexpl, functions) = List.map3Fold(expl1, function differentiateExp(maxIter=maxIter), inDiffwrtCref, inInputData, inDiffType, functions);
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION_VERBOSE) then
print("### Diffed ExpList: \n");
print(stringDelimitList(List.map(dexpl, ExpressionDump.printExpStr), ", ") + "\n");
end if;

(dexplZero, functions) = List.map3Fold(expl1, function differentiateExp(maxIter=maxIter), DAE.CREF_IDENT("$",DAE.T_REAL_DEFAULT,{}), BackendDAE.emptyInputData, BackendDAE.GENERIC_GRADIENT(), functions);
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION_VERBOSE) then
print("### Diffed ExpList extended: \n");
print(stringDelimitList(List.map(dexplZero, ExpressionDump.printExpStr), ", ") + "\n");
end if;

e = DAE.CALL(dpath,dexpl,DAE.CALL_ATTR(dtp,b,false,isImpure,false,DAE.NO_INLINE(),tc));
exp = createPartialArguments(ty, dexpl, dexplZero, expl, e);
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION_VERBOSE) then
print("### differentiated Call :\n");
print(ExpressionDump.printExpStr(e) + "\n");
print("### -> result exp: \n");
Expand Down

0 comments on commit 5056d0b

Please sign in to comment.