Skip to content
This repository was archived by the owner on May 18, 2019. It is now read-only.

Commit 7b8186a

Browse files
Willi BraunOpenModelica-Hudson
authored andcommitted
[BE] fixing function differentiation and tuple calls in algorithms
Belonging to [master]: - #2203 - OpenModelica/OpenModelica-testsuite#856
1 parent d874cd0 commit 7b8186a

File tree

2 files changed

+55
-16
lines changed

2 files changed

+55
-16
lines changed

Compiler/BackEnd/Differentiate.mo

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,17 @@ algorithm
797797
(derivedStatements2, functions) = differentiateStatements(restStatements, inDiffwrtCref, inInputData, inDiffType, derivedStatements1, functions, maxIter);
798798
then (derivedStatements2, functions);
799799

800+
case (currStatement as DAE.STMT_TUPLE_ASSIGN(expExpLst=expLst, exp=rhs as DAE.CALL(), type_= type_, source=source))::restStatements
801+
equation
802+
(dexpLst,functions) = List.map3Fold(expLst, function differentiateExp(maxIter=maxIter), inDiffwrtCref, inInputData, inDiffType, inFunctionTree);
803+
(derivedRHS as DAE.CALL(attr=DAE.CALL_ATTR(ty=type_)), functions) = differentiateExp(rhs, inDiffwrtCref, inInputData, inDiffType, functions, maxIter);
804+
optDerivedStatements1 = {SOME(DAE.STMT_TUPLE_ASSIGN(type_, dexpLst, derivedRHS, source))};
805+
derivedStatements1 = List.flatten(List.map(optDerivedStatements1, List.fromOption));
806+
derivedStatements1 = listAppend(derivedStatements1, {currStatement});
807+
derivedStatements1 = listAppend(derivedStatements1, inStmtsAccum);
808+
(derivedStatements2, functions) = differentiateStatements(restStatements, inDiffwrtCref, inInputData, inDiffType, derivedStatements1, functions, maxIter);
809+
then (derivedStatements2, functions);
810+
800811
case (currStatement as DAE.STMT_ASSIGN_ARR(lhs=lhs, exp=rhs, type_=type_, source=source))::restStatements
801812
equation
802813
(derivedLHS, functions) = differentiateExp(lhs, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter);
@@ -1145,6 +1156,11 @@ algorithm
11451156
//
11461157
// This part contains special rules for GENERIC_GRADIENT()
11471158
//
1159+
case (DAE.CREF(componentRef = cr,ty=tp), DAE.CREF_IDENT(ident="$"), _, BackendDAE.GENERIC_GRADIENT(), _)
1160+
equation
1161+
(res,_) = Expression.makeZeroExpression(Expression.arrayDimension(tp));
1162+
then
1163+
(res, inFunctionTree);
11481164

11491165
// d(x)/d(x) => generate seed variables
11501166
case ((DAE.CREF(componentRef = cr,ty = tp)), _, BackendDAE.DIFFINPUTDATA(independenentVars=SOME(timevars),matrixName=SOME(matrixName)), BackendDAE.GENERIC_GRADIENT(), _)
@@ -1363,8 +1379,9 @@ algorithm
13631379
then
13641380
(zero, inFunctionTree);
13651381

1366-
/* Exclude records here, they are handled component-wise in differentiateFunctionCall */
1367-
case (e as DAE.CALL(attr=DAE.CALL_ATTR(ty=tp)), DAE.CREF_IDENT(ident="$"), _, _, _)
1382+
/* Exclude records here, they are handled component-wise in differentiateFunctionCall
1383+
and builtin function are handled in differentiateCall* */
1384+
case (e as DAE.CALL(attr=DAE.CALL_ATTR(ty=tp,builtin=false)), DAE.CREF_IDENT(ident="$"), _, _, _)
13681385
guard ( not Expression.isRecordCall(e, inFunctionTree) )
13691386
equation
13701387
(zero,_) = Expression.makeZeroExpression(Expression.arrayDimension(tp));
@@ -2206,7 +2223,7 @@ algorithm
22062223
String funcname;
22072224
list<DAE.FuncArg> falst;
22082225
Integer n;
2209-
DAE.Dimensions dims;
2226+
Boolean success;
22102227

22112228
case (DAE.CALL(path=path,expLst=expl,attr=DAE.CALL_ATTR(tuple_=b,builtin=c,isImpure=isImpure,ty=ty,tailCall=tc)), _, _, _, _)
22122229
equation
@@ -2287,6 +2304,7 @@ algorithm
22872304
else
22882305
(functions, inputVarsDer, _, outputVarsDer, _, blst) = getFunctionInOutVars(func, inFunctionTree, inDiffwrtCref, maxIter);
22892306
(dpath, dtp) = getDiffedTypeandName(func, inputVarsDer, outputVarsDer, blst);
2307+
DAE.T_FUNCTION(funcResultType = dtp) = dtp;
22902308
end if;
22912309

22922310
// debug
@@ -2314,18 +2332,17 @@ algorithm
23142332
print(stringDelimitList(List.map(dexpl, ExpressionDump.printExpStr), ", ") + "\n");
23152333
end if;
23162334

2317-
(dexplZero, functions) = List.map3Fold(expl1, function differentiateExp(maxIter=maxIter), DAE.CREF_IDENT("$",DAE.T_REAL_DEFAULT,{}), BackendDAE.emptyInputData, BackendDAE.GENERIC_GRADIENT(), functions);
2318-
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION_VERBOSE) then
2319-
print("### Diffed ExpList extended: \n");
2320-
print(stringDelimitList(List.map(dexplZero, ExpressionDump.printExpStr), ", ") + "\n");
2335+
// try to create zero expression to fill up the arguments, if it fails use the total differentiation
2336+
(dexplZero, functions, success) = tryZeroDiff(expl1, functions, maxIter);
2337+
if success then
2338+
e = DAE.CALL(dpath,dexpl,DAE.CALL_ATTR(dtp,b,false,isImpure,false,DAE.NO_INLINE(),tc));
2339+
exp = createPartialArguments(ty, dexpl, dexplZero, expl, e);
2340+
else
2341+
exp = DAE.CALL(dpath,listAppend(expl,dexpl),DAE.CALL_ATTR(dtp,b,false,isImpure,false,DAE.NO_INLINE(),tc));
23212342
end if;
23222343

2323-
e = DAE.CALL(dpath,dexpl,DAE.CALL_ATTR(dtp,b,false,isImpure,false,DAE.NO_INLINE(),tc));
2324-
exp = createPartialArguments(ty, dexpl, dexplZero, expl, e);
23252344
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION_VERBOSE) then
2326-
print("### differentiated Call :\n");
2327-
print(ExpressionDump.printExpStr(e) + "\n");
2328-
print("### -> result exp: \n");
2345+
print("### differentiated result CALL :\n");
23292346
print(ExpressionDump.printExpStr(exp) + "\n");
23302347
end if;
23312348
then
@@ -2340,6 +2357,21 @@ algorithm
23402357
end matchcontinue;
23412358
end differentiateFunctionCallPartial;
23422359

2360+
function tryZeroDiff
2361+
input output list<DAE.Exp> explist;
2362+
input output DAE.FunctionTree functions;
2363+
input Integer maxIter;
2364+
output Boolean success;
2365+
algorithm
2366+
try
2367+
(explist, functions) := List.map3Fold(explist, function differentiateExp(maxIter=maxIter), DAE.CREF_IDENT("$",DAE.T_REAL_DEFAULT,{}), BackendDAE.emptyInputData, BackendDAE.GENERIC_GRADIENT(), functions);
2368+
success := true;
2369+
else
2370+
explist := {};
2371+
success := false;
2372+
end try;
2373+
end tryZeroDiff;
2374+
23432375
protected function createPartialArguments
23442376
input DAE.Type outputType;
23452377
input list<DAE.Exp> inArgs;
@@ -2348,7 +2380,7 @@ protected function createPartialArguments
23482380
input DAE.Exp inCall;
23492381
output DAE.Exp outExp;
23502382
algorithm
2351-
outExp := match(outputType, inCall)
2383+
outExp := matchcontinue(outputType, inCall)
23522384
local
23532385
Absyn.Path path;
23542386
DAE.CallAttributes attr;
@@ -2367,13 +2399,18 @@ algorithm
23672399
expLst = createPartialArgumentsTuple(tys, inArgs, inDiffedArgs, inOrginalExpl, inCall);
23682400
then DAE.TUPLE(expLst);
23692401

2370-
else
2402+
case (_, _)
23712403
equation
23722404
dims = Expression.arrayDimension(outputType);
23732405
(ezero,_) = Expression.makeZeroExpression(dims);
23742406
e = createPartialDifferentiatedExp(inArgs, inDiffedArgs, inOrginalExpl, inCall, 1, ezero);
23752407
then e;
2376-
end match;
2408+
2409+
//else case as fallback create total differentiation call
2410+
case (_, DAE.CALL(path=path, attr=attr))
2411+
then DAE.CALL(path, listAppend(inOrginalExpl,inArgs), attr);
2412+
2413+
end matchcontinue;
23772414
end createPartialArguments;
23782415

23792416
protected function createPartialArgumentsTuple

Compiler/FrontEnd/Types.mo

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2838,7 +2838,9 @@ algorithm
28382838
(fargs1, _) := List.splitOnBoolList(fargs, inBooltLst);
28392839
newfargs := List.threadMap(inElementLst, fargs1, makeElementFarg);
28402840
newfargs := listAppend(fargs, newfargs);
2841-
rettype := makeElementReturnType(inOutputElementLst);
2841+
// The type of DAE.Element.VAR seems to be wrong,
2842+
// but the original type should be also correct
2843+
//rettype := makeElementReturnType(inOutputElementLst);
28422844
outType := DAE.T_FUNCTION(newfargs,rettype,functionAttributes,tysrc);
28432845
end extendsFunctionTypeArgs;
28442846

0 commit comments

Comments
 (0)