Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
niklwors committed Oct 15, 2015
2 parents cb8e341 + ef387e7 commit 3c084f6
Show file tree
Hide file tree
Showing 16 changed files with 503 additions and 95 deletions.
174 changes: 109 additions & 65 deletions Compiler/BackEnd/Differentiate.mo
Expand Up @@ -60,6 +60,7 @@ protected import BackendVariable;
protected import ClassInf;
protected import ComponentReference;
protected import DAEDump;
protected import DAEDumpTpl;
protected import Debug;
protected import Error;
protected import Expression;
Expand Down Expand Up @@ -99,11 +100,17 @@ protected
BackendDAE.Variables knvars;
algorithm
try
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
BackendDump.debugStrEqnStr("### Differentiate equation\n", inEquation, " w.r.t. time.\n");
end if;
funcs := BackendDAEUtil.getFunctions(inShared);
knvars := BackendDAEUtil.getknvars(inShared);
diffData := BackendDAE.DIFFINPUTDATA(NONE(), SOME(inVariables), SOME(knvars), SOME(inVariables), {}, {}, NONE());
(outEquation, funcs) := differentiateEquation(inEquation, DAE.crefTime, diffData, BackendDAE.DIFFERENTIATION_TIME(), funcs);
outShared := BackendDAEUtil.setSharedFunctionTree(inShared, funcs);
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
BackendDump.debugStrEqnStr("### Result of differentiation\n --> ", outEquation, "\n");
end if;
else
msg := "\nDifferentiate.differentiateEquationTime failed for " + BackendDump.equationString(inEquation) + "\n\n";
source := BackendEquation.equationSource(inEquation);
Expand All @@ -126,12 +133,18 @@ protected
BackendDAE.Variables knvars;
algorithm
try
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
BackendDump.debugStrExpStr("### Differentiate expression\n ", inExp, " w.r.t. time.\n");
end if;
funcs := BackendDAEUtil.getFunctions(inShared);
knvars := BackendDAEUtil.getknvars(inShared);
diffData := BackendDAE.DIFFINPUTDATA(NONE(), SOME(inVariables), SOME(knvars), SOME(inVariables), {}, {}, NONE());
(dexp, funcs) := differentiateExp(inExp, DAE.crefTime, diffData, BackendDAE.DIFFERENTIATION_TIME(), funcs, defaultMaxIter, {});
(outExp, _) := ExpressionSimplify.simplify(dexp);
outShared := BackendDAEUtil.setSharedFunctionTree(inShared, funcs);
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
BackendDump.debugStrExpStr("### Result of differentiation\n --> ", outExp, "n");
end if;
else
// expandDerOperator expects sometime that differentiate fails,
// so the calling function need to take care of the error messages.
Expand All @@ -146,7 +159,7 @@ algorithm
end differentiateExpTime;

public function differentiateExpSolve
"Differentiates an equation with respect to inCref."
"Differentiates an expression with respect to inCref."
input DAE.Exp inExp;
input DAE.ComponentRef inCref;
input Option<DAE.FunctionTree> functions;
Expand All @@ -166,8 +179,14 @@ algorithm
else DAE.emptyFuncTree;
end match;

if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
BackendDump.debugStrExpStrCrefStr("### Differentiate expression\n ", inExp, " w.r.t. ", inCref, "\n");
end if;
(dexp, _) := differentiateExp(inExp, inCref, BackendDAE.emptyInputData, BackendDAE.SIMPLE_DIFFERENTIATION(), fun, defaultMaxIter, {});
(outExp, _) := ExpressionSimplify.simplify(dexp);
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
BackendDump.debugStrExpStr("### Result of differentiation\n --> ", outExp, "\n");
end if;
else
if Flags.isSet(Flags.FAILTRACE) then
Error.addSourceMessage(Error.NON_EXISTING_DERIVATIVE, {ExpressionDump.printExpStr(inExp), ComponentReference.crefStr(inCref)}, sourceInfo());
Expand Down Expand Up @@ -220,41 +239,6 @@ end differentiateExpCrefFullJacobian;
// =============================================================================


protected function differentiateEquations
"Differentiates an equation with respect to a cref."
input list<BackendDAE.Equation> inEquations;
input DAE.ComponentRef inDiffwrtCref;
input BackendDAE.DifferentiateInputData inInputData;
input BackendDAE.DifferentiationType inDiffType;
input list<BackendDAE.Equation> inEquationsAccum;
input DAE.FunctionTree inFunctionTree;
output list<BackendDAE.Equation> outEquations;
output DAE.FunctionTree outFunctionTree;
algorithm
(outEquations,outFunctionTree) := matchcontinue (inEquations)
local
DAE.FunctionTree funcs;
list<BackendDAE.Equation> rest, eqns;
BackendDAE.Equation eqn;

case {} then (listReverse(inEquationsAccum), inFunctionTree);

// equations
case eqn::rest
equation
(eqn, funcs) = differentiateEquation(eqn, inDiffwrtCref, inInputData, inDiffType, inFunctionTree);
eqns = listAppend({eqn}, inEquationsAccum);
(eqns, funcs) = differentiateEquations(rest, inDiffwrtCref, inInputData, inDiffType, eqns, funcs);
then (eqns, funcs);

case eqn::_
equation
Error.addSourceMessage(Error.NON_EXISTING_DERIVATIVE, {BackendDump.equationString(eqn), ComponentReference.crefStr(inDiffwrtCref)}, sourceInfo());
then
fail();
end matchcontinue;
end differentiateEquations;

public function differentiateEquation
"Differentiates an equation with respect to a cref."
input BackendDAE.Equation inEquation;
Expand All @@ -265,6 +249,11 @@ public function differentiateEquation
output BackendDAE.Equation outEquation;
output DAE.FunctionTree outFunctionTree;
algorithm
try
// Debug dump
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
BackendDump.debugStrEqnStr("### Differentiate equation\n ", inEquation, " w.r.t. " + ComponentReference.crefStr(inDiffwrtCref) + "\n");
end if;
(outEquation, outFunctionTree) := match inEquation
local
DAE.Exp e1_1, e2_1, e1_2, e2_2, e1, e2;
Expand Down Expand Up @@ -382,18 +371,58 @@ algorithm
(BackendDAE.IF_EQUATION(expExpLst, eqnslst, eqns, source, BackendDAE.EQUATION_ATTRIBUTES(false, eqKind)), funcs);

case BackendDAE.WHEN_EQUATION(size=size, whenEquation=whenEqn, source=source, attr=BackendDAE.EQUATION_ATTRIBUTES(kind=eqKind))
equation
equation
(whenEqn, funcs) = differentiateWhenEquations(whenEqn, inDiffwrtCref, inInputData, inDiffType, inFunctionTree);
then
(BackendDAE.WHEN_EQUATION(size, whenEqn, source, BackendDAE.EQUATION_ATTRIBUTES(false, eqKind)), funcs);
else equation
Error.addSourceMessage(Error.NON_EXISTING_DERIVATIVE, {BackendDump.equationString(inEquation), ComponentReference.crefStr(inDiffwrtCref)}, sourceInfo());
then fail();
end match;
// Debug dump
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
BackendDump.debugStrEqnStr("### Result of differentiation\n --> ", outEquation,"\n");
end if;
else
Error.addSourceMessage(Error.NON_EXISTING_DERIVATIVE, {BackendDump.equationString(inEquation), ComponentReference.crefStr(inDiffwrtCref)}, sourceInfo());
fail();
end try;
end differentiateEquation;

else
protected function differentiateEquations
"Differentiates an equation with respect to a cref."
input list<BackendDAE.Equation> inEquations;
input DAE.ComponentRef inDiffwrtCref;
input BackendDAE.DifferentiateInputData inInputData;
input BackendDAE.DifferentiationType inDiffType;
input list<BackendDAE.Equation> inEquationsAccum;
input DAE.FunctionTree inFunctionTree;
output list<BackendDAE.Equation> outEquations;
output DAE.FunctionTree outFunctionTree;
algorithm
(outEquations,outFunctionTree) := matchcontinue (inEquations)
local
DAE.FunctionTree funcs;
list<BackendDAE.Equation> rest, eqns;
BackendDAE.Equation eqn;

case {} then (listReverse(inEquationsAccum), inFunctionTree);

// equations
case eqn::rest
equation
Error.addSourceMessage(Error.NON_EXISTING_DERIVATIVE, {BackendDump.equationString(inEquation), ComponentReference.crefStr(inDiffwrtCref)}, sourceInfo());
(eqn, funcs) = differentiateEquation(eqn, inDiffwrtCref, inInputData, inDiffType, inFunctionTree);
eqns = listAppend({eqn}, inEquationsAccum);
(eqns, funcs) = differentiateEquations(rest, inDiffwrtCref, inInputData, inDiffType, eqns, funcs);
then (eqns, funcs);

case eqn::_
equation
Error.addSourceMessage(Error.NON_EXISTING_DERIVATIVE, {BackendDump.equationString(eqn), ComponentReference.crefStr(inDiffwrtCref)}, sourceInfo());
then
fail();
end match;
end differentiateEquation;
end matchcontinue;
end differentiateEquations;

protected function differentiateEquationsLst
"Differentiates a list of an equation list with respect to a cref.
Expand Down Expand Up @@ -1545,7 +1574,7 @@ algorithm
end match;
end differentiateCallExp1Arg;

function createFromNCall2ArgsCall
protected function createFromNCall2ArgsCall
input String funcName;
input list<DAE.Exp> expl;
input DAE.Type tp;
Expand Down Expand Up @@ -2016,8 +2045,16 @@ algorithm
//differentiate function partial
case (e as DAE.CALL(), _, _, _, _)
equation
// Debug dump
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
BackendDump.debugStrExpStr("### Differentiate call\n ", e, " w.r.t. " + ComponentReference.crefStr(inDiffwrtCref) + "\n");
end if;
(e, functions) = differentiateFunctionCallPartial(e, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter, expStack);
(e,_,_,_) = Inline.inlineExp(e,(SOME(functions),{DAE.NORM_INLINE(),DAE.NO_INLINE()}),DAE.emptyElementSource);
// Debug dump
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
BackendDump.debugStrExpStr("### result output -> ", e, " w.r.t. " + ComponentReference.crefStr(inDiffwrtCref) + "\n");
end if;
then
(e, functions);

Expand Down Expand Up @@ -2107,10 +2144,12 @@ algorithm
funcname = Util.modelicaStringToCStr(Absyn.pathString(path), false);
diffFuncData = BackendDAE.DIFFINPUTDATA(NONE(),NONE(),NONE(),NONE(),{},{},SOME(funcname));
(dexplZero, functions) = List.map3Fold(expl1, function differentiateExp(maxIter=maxIter, inExpStack=expStack), DAE.CREF_IDENT("$",DAE.T_REAL_DEFAULT,{}), diffFuncData, BackendDAE.GENERIC_GRADIENT(), functions);
//dexpl = listAppend(expl, dexpl);
//print("Start creation of partial Der\n");
//print("Diffed ExpList: \n");
//print(stringDelimitList(List.map(dexpl, ExpressionDump.printExpStr), ", ") + "\n");
// debug dump
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
print("### differentiated argument list:\n");
print("Diffed ExpList: \n");
print(stringDelimitList(List.map(dexpl, ExpressionDump.printExpStr), ", ") + "\n");
end if;
e = DAE.CALL(dpath,expl1,DAE.CALL_ATTR(ty,b,c,isImpure,false,dinl,tc));
e = createPartialArguments(ty, dexpl, dexplZero, expl, e);
then
Expand Down Expand Up @@ -2149,9 +2188,12 @@ algorithm
(dfunc, functions, blst) = differentiatePartialFunction(func, inDiffwrtCref, NONE(), inInputData, inDiffType, inFunctionTree, maxIter, expStack);

dpath = DAEUtil.functionName(dfunc);

// debug
//funstring = Tpl.tplString(DAEDumpTpl.dumpFunction, dfunc);
//print("\n\nDER.Function: \n" + funstring + "\n\n");
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION_VERBOSE) then
funstring = Tpl.tplString(DAEDumpTpl.dumpFunction, dfunc);
print("### Differentiate function: \n" + funstring + "\n\n");
end if;

functions = DAEUtil.addDaeFunction({dfunc}, functions);
// add differentiated function as function mapper
Expand All @@ -2160,32 +2202,34 @@ algorithm

// debug
// differentiate expl
//print("Finished differentiate Expression in Call.\n");
//print("DER.Function call : \n" + ExpressionDump.printExpStr(e) + "\n");
//print("Diff ExpList: \n");
//print(stringDelimitList(List.map(expl, ExpressionDump.printExpStr), ", ") + "\n");
//print("Diff ExpList Types: \n");
//print(stringDelimitList(List.map(List.map(expl, Expression.typeof), Types.printTypeStr), " | ") + "\n");

if Flags.isSet(Flags.DEBUG_DIFFERENTIATION_VERBOSE) then
print("### Detailed arguments list: \n");
print(stringDelimitList(List.map(expl, ExpressionDump.printExpStr), ", ") + "\n");
print("### and argument types: \n");
print(stringDelimitList(List.map(List.map(expl, Expression.typeof), Types.printTypeStr), " | ") + "\n");
print("### and output type: " + Types.printTypeStr(ty) + "\n");
end if;

// create differentiated call arguments
expBoolLst = List.threadTuple(expl, blst);
expBoolLst = List.filterOnTrue(expBoolLst, Util.tuple22);
expl1 = List.map(expBoolLst, Util.tuple21);
(dexpl, functions) = List.map3Fold(expl1, function differentiateExp(maxIter=maxIter, inExpStack=expStack), inDiffwrtCref, inInputData, inDiffType, functions);
(dexplZero, functions) = List.map3Fold(expl1, function differentiateExp(maxIter=maxIter, inExpStack=expStack), DAE.CREF_IDENT("$",DAE.T_REAL_DEFAULT,{}), BackendDAE.emptyInputData, BackendDAE.GENERIC_GRADIENT(), functions);
//dexpl = listAppend(expl, dexpl);
//print("Start creation of partial Der\n");
//print("Diffed ExpList: \n");
//print(stringDelimitList(List.map(dexpl, ExpressionDump.printExpStr), ", ") + "\n");
//print(" output Type: " + Types.printTypeStr(ty) + "\n");

if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
print("### differentiated argument list:\n");
print("Diffed ExpList: \n");
print(stringDelimitList(List.map(dexpl, ExpressionDump.printExpStr), ", ") + "\n");
end if;
/*
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION_VERBOSE) then
funstring = Tpl.tplString(DAEDumpTpl.dumpFunctions, DAEUtil.getFunctionList(functions));
print("### FunctionTree: \n" + funstring + "\n\n");
end if;
*/
e = DAE.CALL(dpath,dexpl,DAE.CALL_ATTR(ty,b,false,isImpure,false,DAE.NO_INLINE(),tc));
exp = createPartialArguments(ty, dexpl, dexplZero, expl, e);

// debug
//print("Finished differentiate Expression in Call.\n");
//print("DER.Function call : \n" + ExpressionDump.printExpStr(e) + "\n");

then
(exp, functions);

Expand Down
14 changes: 8 additions & 6 deletions Compiler/FrontEnd/Expression.mo
Expand Up @@ -2244,14 +2244,15 @@ algorithm
local
Type tp;
Operator op;
DAE.Exp e1,e2,e3,e;
DAE.Exp e1,e2,e3,e,iterExp,operExp;
list<DAE.Exp> explist,exps;
Absyn.Path p;
String msg;
DAE.Type ty;
DAE.Type ty, iterTp, operTp;
list<DAE.Type> tys;
Integer i,i1,i2;
DAE.Dimension dim;
DAE.Dimensions iterdims;

case (DAE.ICONST()) then DAE.T_INTEGER_DEFAULT;
case (DAE.RCONST()) then DAE.T_REAL_DEFAULT;
Expand Down Expand Up @@ -2290,12 +2291,13 @@ algorithm
case DAE.RSUB() then inExp.ty;
case (DAE.CODE(ty = tp)) then tp;
/* array reduction with known size */
case (DAE.REDUCTION(iterators={DAE.REDUCTIONITER(exp=e,guardExp=NONE())},reductionInfo=DAE.REDUCTIONINFO(exprType=ty as DAE.T_ARRAY(dims=dim::_),path = Absyn.IDENT("array"))))
case (DAE.REDUCTION(iterators={DAE.REDUCTIONITER(exp=iterExp,guardExp=NONE())},expr = operExp, reductionInfo=DAE.REDUCTIONINFO(exprType=ty as DAE.T_ARRAY(dims=dim::_),path = Absyn.IDENT("array"))))
equation
false = dimensionKnown(dim);
DAE.T_ARRAY(dims={dim}) = typeof(e);
true = dimensionKnown(dim);
tp = liftArrayR(Types.unliftArray(Types.simplifyType(ty)),dim);
iterTp = typeof(iterExp);
operTp = typeof(operExp);
DAE.T_ARRAY(dims=iterdims) = iterTp;
tp = Types.liftTypeWithDims(operTp, iterdims);
then tp;
case (DAE.REDUCTION(reductionInfo=DAE.REDUCTIONINFO(exprType=ty)))
then Types.simplifyType(ty);
Expand Down
4 changes: 3 additions & 1 deletion Compiler/FrontEnd/InstVar.mo
Expand Up @@ -1167,7 +1167,9 @@ algorithm
// Propagate the final prefix from the modifier.
//fin = InstUtil.propagateModFinal(mod, fin);

attr = stripVarAttrDirection(cr, ih, inState, inPrefix, attr);
if not Flags.getConfigBool(Flags.USE_LOCAL_DIRECTION) then
attr = stripVarAttrDirection(cr, ih, inState, inPrefix, attr);
end if;

// Propagate prefixes to any elements inside this components if it's a
// structured component.
Expand Down

0 comments on commit 3c084f6

Please sign in to comment.