Skip to content

Commit

Permalink
[BE] fixing function differentiation and tuple calls in algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
Willi Braun authored and OpenModelica-Hudson committed Feb 16, 2018
1 parent d874cd0 commit 7b8186a
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 16 deletions.
67 changes: 52 additions & 15 deletions Compiler/BackEnd/Differentiate.mo
Expand Up @@ -797,6 +797,17 @@ algorithm
(derivedStatements2, functions) = differentiateStatements(restStatements, inDiffwrtCref, inInputData, inDiffType, derivedStatements1, functions, maxIter);
then (derivedStatements2, functions);

case (currStatement as DAE.STMT_TUPLE_ASSIGN(expExpLst=expLst, exp=rhs as DAE.CALL(), type_= type_, source=source))::restStatements
equation
(dexpLst,functions) = List.map3Fold(expLst, function differentiateExp(maxIter=maxIter), inDiffwrtCref, inInputData, inDiffType, inFunctionTree);
(derivedRHS as DAE.CALL(attr=DAE.CALL_ATTR(ty=type_)), functions) = differentiateExp(rhs, inDiffwrtCref, inInputData, inDiffType, functions, maxIter);
optDerivedStatements1 = {SOME(DAE.STMT_TUPLE_ASSIGN(type_, dexpLst, derivedRHS, source))};
derivedStatements1 = List.flatten(List.map(optDerivedStatements1, List.fromOption));
derivedStatements1 = listAppend(derivedStatements1, {currStatement});
derivedStatements1 = listAppend(derivedStatements1, inStmtsAccum);
(derivedStatements2, functions) = differentiateStatements(restStatements, inDiffwrtCref, inInputData, inDiffType, derivedStatements1, functions, maxIter);
then (derivedStatements2, functions);

case (currStatement as DAE.STMT_ASSIGN_ARR(lhs=lhs, exp=rhs, type_=type_, source=source))::restStatements
equation
(derivedLHS, functions) = differentiateExp(lhs, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter);
Expand Down Expand Up @@ -1145,6 +1156,11 @@ algorithm
//
// This part contains special rules for GENERIC_GRADIENT()
//
case (DAE.CREF(componentRef = cr,ty=tp), DAE.CREF_IDENT(ident="$"), _, BackendDAE.GENERIC_GRADIENT(), _)
equation
(res,_) = Expression.makeZeroExpression(Expression.arrayDimension(tp));
then
(res, inFunctionTree);

// d(x)/d(x) => generate seed variables
case ((DAE.CREF(componentRef = cr,ty = tp)), _, BackendDAE.DIFFINPUTDATA(independenentVars=SOME(timevars),matrixName=SOME(matrixName)), BackendDAE.GENERIC_GRADIENT(), _)
Expand Down Expand Up @@ -1363,8 +1379,9 @@ algorithm
then
(zero, inFunctionTree);

/* Exclude records here, they are handled component-wise in differentiateFunctionCall */
case (e as DAE.CALL(attr=DAE.CALL_ATTR(ty=tp)), DAE.CREF_IDENT(ident="$"), _, _, _)
/* Exclude records here, they are handled component-wise in differentiateFunctionCall
and builtin function are handled in differentiateCall* */
case (e as DAE.CALL(attr=DAE.CALL_ATTR(ty=tp,builtin=false)), DAE.CREF_IDENT(ident="$"), _, _, _)
guard ( not Expression.isRecordCall(e, inFunctionTree) )
equation
(zero,_) = Expression.makeZeroExpression(Expression.arrayDimension(tp));
Expand Down Expand Up @@ -2206,7 +2223,7 @@ algorithm
String funcname;
list<DAE.FuncArg> falst;
Integer n;
DAE.Dimensions dims;
Boolean success;

case (DAE.CALL(path=path,expLst=expl,attr=DAE.CALL_ATTR(tuple_=b,builtin=c,isImpure=isImpure,ty=ty,tailCall=tc)), _, _, _, _)
equation
Expand Down Expand Up @@ -2287,6 +2304,7 @@ algorithm
else
(functions, inputVarsDer, _, outputVarsDer, _, blst) = getFunctionInOutVars(func, inFunctionTree, inDiffwrtCref, maxIter);
(dpath, dtp) = getDiffedTypeandName(func, inputVarsDer, outputVarsDer, blst);
DAE.T_FUNCTION(funcResultType = dtp) = dtp;
end if;

// debug
Expand Down Expand Up @@ -2314,18 +2332,17 @@ algorithm
print(stringDelimitList(List.map(dexpl, ExpressionDump.printExpStr), ", ") + "\n");
end if;

(dexplZero, functions) = List.map3Fold(expl1, function differentiateExp(maxIter=maxIter), DAE.CREF_IDENT("$",DAE.T_REAL_DEFAULT,{}), BackendDAE.emptyInputData, BackendDAE.GENERIC_GRADIENT(), functions);
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION_VERBOSE) then
print("### Diffed ExpList extended: \n");
print(stringDelimitList(List.map(dexplZero, ExpressionDump.printExpStr), ", ") + "\n");
// try to create zero expression to fill up the arguments, if it fails use the total differentiation
(dexplZero, functions, success) = tryZeroDiff(expl1, functions, maxIter);
if success then
e = DAE.CALL(dpath,dexpl,DAE.CALL_ATTR(dtp,b,false,isImpure,false,DAE.NO_INLINE(),tc));
exp = createPartialArguments(ty, dexpl, dexplZero, expl, e);
else
exp = DAE.CALL(dpath,listAppend(expl,dexpl),DAE.CALL_ATTR(dtp,b,false,isImpure,false,DAE.NO_INLINE(),tc));
end if;

e = DAE.CALL(dpath,dexpl,DAE.CALL_ATTR(dtp,b,false,isImpure,false,DAE.NO_INLINE(),tc));
exp = createPartialArguments(ty, dexpl, dexplZero, expl, e);
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION_VERBOSE) then
print("### differentiated Call :\n");
print(ExpressionDump.printExpStr(e) + "\n");
print("### -> result exp: \n");
print("### differentiated result CALL :\n");
print(ExpressionDump.printExpStr(exp) + "\n");
end if;
then
Expand All @@ -2340,6 +2357,21 @@ algorithm
end matchcontinue;
end differentiateFunctionCallPartial;

function tryZeroDiff
input output list<DAE.Exp> explist;
input output DAE.FunctionTree functions;
input Integer maxIter;
output Boolean success;
algorithm
try
(explist, functions) := List.map3Fold(explist, function differentiateExp(maxIter=maxIter), DAE.CREF_IDENT("$",DAE.T_REAL_DEFAULT,{}), BackendDAE.emptyInputData, BackendDAE.GENERIC_GRADIENT(), functions);
success := true;
else
explist := {};
success := false;
end try;
end tryZeroDiff;

protected function createPartialArguments
input DAE.Type outputType;
input list<DAE.Exp> inArgs;
Expand All @@ -2348,7 +2380,7 @@ protected function createPartialArguments
input DAE.Exp inCall;
output DAE.Exp outExp;
algorithm
outExp := match(outputType, inCall)
outExp := matchcontinue(outputType, inCall)
local
Absyn.Path path;
DAE.CallAttributes attr;
Expand All @@ -2367,13 +2399,18 @@ algorithm
expLst = createPartialArgumentsTuple(tys, inArgs, inDiffedArgs, inOrginalExpl, inCall);
then DAE.TUPLE(expLst);

else
case (_, _)
equation
dims = Expression.arrayDimension(outputType);
(ezero,_) = Expression.makeZeroExpression(dims);
e = createPartialDifferentiatedExp(inArgs, inDiffedArgs, inOrginalExpl, inCall, 1, ezero);
then e;
end match;

//else case as fallback create total differentiation call
case (_, DAE.CALL(path=path, attr=attr))
then DAE.CALL(path, listAppend(inOrginalExpl,inArgs), attr);

end matchcontinue;
end createPartialArguments;

protected function createPartialArgumentsTuple
Expand Down
4 changes: 3 additions & 1 deletion Compiler/FrontEnd/Types.mo
Expand Up @@ -2838,7 +2838,9 @@ algorithm
(fargs1, _) := List.splitOnBoolList(fargs, inBooltLst);
newfargs := List.threadMap(inElementLst, fargs1, makeElementFarg);
newfargs := listAppend(fargs, newfargs);
rettype := makeElementReturnType(inOutputElementLst);
// The type of DAE.Element.VAR seems to be wrong,
// but the original type should be also correct
//rettype := makeElementReturnType(inOutputElementLst);
outType := DAE.T_FUNCTION(newfargs,rettype,functionAttributes,tysrc);
end extendsFunctionTypeArgs;

Expand Down

0 comments on commit 7b8186a

Please sign in to comment.