@@ -797,6 +797,17 @@ algorithm
797797 (derivedStatements2, functions) = differentiateStatements(restStatements, inDiffwrtCref, inInputData, inDiffType, derivedStatements1, functions, maxIter);
798798 then (derivedStatements2, functions);
799799
800+ case (currStatement as DAE . STMT_TUPLE_ASSIGN (expExpLst= expLst, exp= rhs as DAE . CALL (), type_= type_, source= source))::restStatements
801+ equation
802+ (dexpLst,functions) = List . map3Fold(expLst, function differentiateExp(maxIter= maxIter), inDiffwrtCref, inInputData, inDiffType, inFunctionTree);
803+ (derivedRHS as DAE . CALL (attr= DAE . CALL_ATTR (ty= type_)), functions) = differentiateExp(rhs, inDiffwrtCref, inInputData, inDiffType, functions, maxIter);
804+ optDerivedStatements1 = {SOME (DAE . STMT_TUPLE_ASSIGN (type_, dexpLst, derivedRHS, source))};
805+ derivedStatements1 = List . flatten(List . map(optDerivedStatements1, List . fromOption));
806+ derivedStatements1 = listAppend(derivedStatements1, {currStatement});
807+ derivedStatements1 = listAppend(derivedStatements1, inStmtsAccum);
808+ (derivedStatements2, functions) = differentiateStatements(restStatements, inDiffwrtCref, inInputData, inDiffType, derivedStatements1, functions, maxIter);
809+ then (derivedStatements2, functions);
810+
800811 case (currStatement as DAE . STMT_ASSIGN_ARR (lhs= lhs, exp= rhs, type_= type_, source= source))::restStatements
801812 equation
802813 (derivedLHS, functions) = differentiateExp(lhs, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter);
@@ -1145,6 +1156,11 @@ algorithm
11451156 //
11461157 // This part contains special rules for GENERIC_GRADIENT()
11471158 //
1159+ case (DAE . CREF (componentRef = cr,ty= tp), DAE . CREF_IDENT (ident= "$" ), _, BackendDAE . GENERIC_GRADIENT (), _)
1160+ equation
1161+ (res,_) = Expression . makeZeroExpression(Expression . arrayDimension(tp));
1162+ then
1163+ (res, inFunctionTree);
11481164
11491165 // d(x)/d(x) => generate seed variables
11501166 case ((DAE . CREF (componentRef = cr,ty = tp)), _, BackendDAE . DIFFINPUTDATA (independenentVars= SOME (timevars),matrixName= SOME (matrixName)), BackendDAE . GENERIC_GRADIENT (), _)
@@ -1363,8 +1379,9 @@ algorithm
13631379 then
13641380 (zero, inFunctionTree);
13651381
1366- /* Exclude records here, they are handled component-wise in differentiateFunctionCall */
1367- case (e as DAE . CALL (attr= DAE . CALL_ATTR (ty= tp)), DAE . CREF_IDENT (ident= "$" ), _, _, _)
1382+ /* Exclude records here, they are handled component-wise in differentiateFunctionCall
1383+ and builtin function are handled in differentiateCall* */
1384+ case (e as DAE . CALL (attr= DAE . CALL_ATTR (ty= tp,builtin= false )), DAE . CREF_IDENT (ident= "$" ), _, _, _)
13681385 guard ( not Expression . isRecordCall(e, inFunctionTree) )
13691386 equation
13701387 (zero,_) = Expression . makeZeroExpression(Expression . arrayDimension(tp));
@@ -2206,7 +2223,7 @@ algorithm
22062223 String funcname;
22072224 list< DAE . FuncArg > falst;
22082225 Integer n;
2209- DAE . Dimensions dims ;
2226+ Boolean success ;
22102227
22112228 case (DAE . CALL (path= path,expLst= expl,attr= DAE . CALL_ATTR (tuple_= b,builtin= c,isImpure= isImpure,ty= ty,tailCall= tc)), _, _, _, _)
22122229 equation
@@ -2287,6 +2304,7 @@ algorithm
22872304 else
22882305 (functions, inputVarsDer, _, outputVarsDer, _, blst) = getFunctionInOutVars(func , inFunctionTree, inDiffwrtCref, maxIter);
22892306 (dpath, dtp) = getDiffedTypeandName(func , inputVarsDer, outputVarsDer, blst);
2307+ DAE . T_FUNCTION (funcResultType = dtp) = dtp;
22902308 end if ;
22912309
22922310 // debug
@@ -2314,18 +2332,17 @@ algorithm
23142332 print(stringDelimitList(List . map(dexpl, ExpressionDump . printExpStr), ", " ) + " \n " );
23152333 end if ;
23162334
2317- (dexplZero, functions) = List . map3Fold(expl1, function differentiateExp(maxIter= maxIter), DAE . CREF_IDENT ("$" ,DAE . T_REAL_DEFAULT ,{}), BackendDAE . emptyInputData, BackendDAE . GENERIC_GRADIENT (), functions);
2318- if Flags . isSet(Flags . DEBUG_DIFFERENTIATION_VERBOSE ) then
2319- print("### Diffed ExpList extended: \n " );
2320- print(stringDelimitList(List . map(dexplZero, ExpressionDump . printExpStr), ", " ) + " \n " );
2335+ // try to create zero expression to fill up the arguments, if it fails use the total differentiation
2336+ (dexplZero, functions, success) = tryZeroDiff(expl1, functions, maxIter);
2337+ if success then
2338+ e = DAE . CALL (dpath,dexpl,DAE . CALL_ATTR (dtp,b,false ,isImpure,false ,DAE . NO_INLINE (),tc));
2339+ exp = createPartialArguments(ty, dexpl, dexplZero, expl, e);
2340+ else
2341+ exp = DAE . CALL (dpath,listAppend(expl,dexpl),DAE . CALL_ATTR (dtp,b,false ,isImpure,false ,DAE . NO_INLINE (),tc));
23212342 end if ;
23222343
2323- e = DAE . CALL (dpath,dexpl,DAE . CALL_ATTR (dtp,b,false ,isImpure,false ,DAE . NO_INLINE (),tc));
2324- exp = createPartialArguments(ty, dexpl, dexplZero, expl, e);
23252344 if Flags . isSet(Flags . DEBUG_DIFFERENTIATION_VERBOSE ) then
2326- print("### differentiated Call : \n " );
2327- print(ExpressionDump . printExpStr(e) + " \n " );
2328- print("### -> result exp: \n " );
2345+ print("### differentiated result CALL : \n " );
23292346 print(ExpressionDump . printExpStr(exp) + " \n " );
23302347 end if ;
23312348 then
@@ -2340,6 +2357,21 @@ algorithm
23402357 end matchcontinue;
23412358end differentiateFunctionCallPartial;
23422359
2360+ function tryZeroDiff
2361+ input output list< DAE . Exp > explist;
2362+ input output DAE . FunctionTree functions;
2363+ input Integer maxIter;
2364+ output Boolean success;
2365+ algorithm
2366+ try
2367+ (explist, functions) := List . map3Fold(explist, function differentiateExp(maxIter= maxIter), DAE . CREF_IDENT ("$" ,DAE . T_REAL_DEFAULT ,{}), BackendDAE . emptyInputData, BackendDAE . GENERIC_GRADIENT (), functions);
2368+ success := true ;
2369+ else
2370+ explist := {};
2371+ success := false ;
2372+ end try ;
2373+ end tryZeroDiff;
2374+
23432375protected function createPartialArguments
23442376 input DAE . Type outputType;
23452377 input list< DAE . Exp > inArgs;
@@ -2348,7 +2380,7 @@ protected function createPartialArguments
23482380 input DAE . Exp inCall;
23492381 output DAE . Exp outExp;
23502382algorithm
2351- outExp := match (outputType, inCall)
2383+ outExp := matchcontinue (outputType, inCall)
23522384 local
23532385 Absyn . Path path;
23542386 DAE . CallAttributes attr;
@@ -2367,13 +2399,18 @@ algorithm
23672399 expLst = createPartialArgumentsTuple(tys, inArgs, inDiffedArgs, inOrginalExpl, inCall);
23682400 then DAE . TUPLE (expLst);
23692401
2370- else
2402+ case (_, _)
23712403 equation
23722404 dims = Expression . arrayDimension(outputType);
23732405 (ezero,_) = Expression . makeZeroExpression(dims);
23742406 e = createPartialDifferentiatedExp(inArgs, inDiffedArgs, inOrginalExpl, inCall, 1 , ezero);
23752407 then e;
2376- end match;
2408+
2409+ // else case as fallback create total differentiation call
2410+ case (_, DAE . CALL (path= path, attr= attr))
2411+ then DAE . CALL (path, listAppend(inOrginalExpl,inArgs), attr);
2412+
2413+ end matchcontinue;
23772414end createPartialArguments;
23782415
23792416protected function createPartialArgumentsTuple
0 commit comments