Skip to content

Commit

Permalink
[janitor] Cleanup old backend differentiateExp (#10410)
Browse files Browse the repository at this point in the history
  • Loading branch information
phannebohm committed Mar 21, 2023
1 parent 16cbb66 commit 7968170
Showing 1 changed file with 63 additions and 105 deletions.
168 changes: 63 additions & 105 deletions OMCompiler/Compiler/BackEnd/Differentiate.mo
Expand Up @@ -593,35 +593,46 @@ algorithm
DAE.ComponentRef cref;

// types that are not differentiated
case DAE.BCONST() then (inExp, inFunctionTree);
case DAE.SCONST() then (inExp, inFunctionTree);
case DAE.BCONST() then (inExp, inFunctionTree);
case DAE.CLKCONST() then (inExp, inFunctionTree);
case DAE.ENUM_LITERAL() then (inExp, inFunctionTree);

// constants => results in zero
case DAE.ICONST() then (DAE.ICONST(0), inFunctionTree);
case DAE.RCONST() then (DAE.RCONST(0.0), inFunctionTree);

case DAE.RECORD(path = p, exps = expl, comp = strLst, ty=tp)
algorithm
sub := {};
functionTree := inFunctionTree;
for e in expl loop
(e1, functionTree) := differentiateExp(e,inDiffwrtCref, inInputData, inDiffType, functionTree, maxIter);
sub := e1 :: sub;
end for;
then (DAE.RECORD(p, listReverse(sub), strLst, tp), functionTree);

// differentiate cref
case DAE.CREF(componentRef=cref, ty=tp) equation

case DAE.CREF(componentRef=cref, ty=tp) algorithm
if ComponentReference.isStartCref(cref) then
// differentiate start value
res = Expression.makeConstZero(tp);
functionTree = inFunctionTree;
res := Expression.makeConstZero(tp);
functionTree := inFunctionTree;
else
(res, functionTree) = differentiateCrefs(inExp, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
(res, functionTree) := differentiateCrefs(inExp, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
end if;
then (res, functionTree);

case DAE.BINARY() algorithm
(res, functionTree) := differentiateBinary(inExp, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
res := ExpressionSimplify.simplifyBinaryExp(res);
then (res, functionTree);

case DAE.UNARY(operator=op, exp=e1) algorithm
(res, functionTree) := differentiateExp(e1, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
res := DAE.UNARY(op, res);
res := ExpressionSimplify.simplifyUnaryExp(res);
then (res, functionTree);

// boolean expression, e.g. relation, are left as they are
case DAE.LBINARY() then (inExp, inFunctionTree);
case DAE.LUNARY() then (inExp, inFunctionTree);
case DAE.RELATION() then (inExp, inFunctionTree);

case DAE.IFEXP(expCond=e1, expThen=e2, expElse=e3) algorithm
(res1, functionTree) := differentiateExp(e2, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
(res2, functionTree) := differentiateExp(e3, inDiffwrtCref, inInputData, inDiffType, functionTree, maxIter-1);
res := DAE.IFEXP(e1, res1, res2);
(res, _) := ExpressionSimplify.simplify1(res);
then (res, functionTree);

// differentiate homotopy
Expand All @@ -641,72 +652,55 @@ algorithm
ticket: #5595
*/
case DAE.CALL(path=p as Absyn.IDENT(name="semiLinear"), expLst={e1, e2, e3}, attr=attr)
guard(Expression.expHasCref(e2, inDiffwrtCref) or Expression.expHasCref(e3, inDiffwrtCref))
guard(Expression.expHasCref(e2, inDiffwrtCref) or Expression.expHasCref(e3, inDiffwrtCref))
then fail();

// differentiate call
case DAE.CALL() equation

(res, functionTree) = differentiateCalls(inExp, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
(res,_) = ExpressionSimplify.simplify1(res);

then (res, functionTree);

// differentiate binary
case DAE.BINARY() equation

(res, functionTree) = differentiateBinary(inExp, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
(res) = ExpressionSimplify.simplifyBinaryExp(res);

case DAE.CALL() algorithm
(res, functionTree) := differentiateCalls(inExp, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
(res, _) := ExpressionSimplify.simplify1(res);
then (res, functionTree);

// differentiate operator
case DAE.UNARY(operator=op, exp=e1) equation

(res, functionTree) = differentiateExp(e1, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);

res = DAE.UNARY(op,res);
(res) = ExpressionSimplify.simplifyUnaryExp(res);

case DAE.RECORD(path = p, exps = expl, comp = strLst, ty=tp) algorithm
sub := {};
functionTree := inFunctionTree;
for e in expl loop
(e1, functionTree) := differentiateExp(e, inDiffwrtCref, inInputData, inDiffType, functionTree, maxIter);
sub := e1 :: sub;
end for;
then (DAE.RECORD(p, listReverse(sub), strLst, tp), functionTree);

case DAE.ARRAY(ty=tp, scalar=b, array=expl) algorithm
(expl, functionTree) := List.map3Fold(expl, function differentiateExp(maxIter=maxIter-1), inDiffwrtCref, inInputData, inDiffType, inFunctionTree);
res := DAE.ARRAY(tp, b, expl);
(res, _) := ExpressionSimplify.simplify1(res);
then (res, functionTree);

// differentiate cast
case DAE.CAST(ty=tp, exp=e1) equation

(res, functionTree) = differentiateExp(e1, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
(res,_) = ExpressionSimplify.simplify1(res);

then (DAE.CAST(tp, res), functionTree);

// differentiate asub
case DAE.ASUB(exp=e1, sub=sub) equation

(res1, functionTree) = differentiateExp(e1, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);

res = Expression.makeASUB(res1,sub);
(res,_) = ExpressionSimplify.simplify1(res);

case DAE.MATRIX(ty=tp, integer=i, matrix=matrix) algorithm
(dmatrix, functionTree) := List.map3FoldList(matrix, function differentiateExp(maxIter=maxIter-1), inDiffwrtCref, inInputData, inDiffType, inFunctionTree);
res := DAE.MATRIX(tp, i, dmatrix);
(res, _) := ExpressionSimplify.simplify1(res);
then (res, functionTree);

case DAE.ARRAY(ty=tp, scalar=b, array=expl) equation

(expl, functionTree) = List.map3Fold(expl, function differentiateExp(maxIter=maxIter-1), inDiffwrtCref, inInputData, inDiffType, inFunctionTree);

res = DAE.ARRAY(tp, b, expl);
(res,_) = ExpressionSimplify.simplify1(res);
case DAE.RANGE() then (inExp, inFunctionTree);

case DAE.TUPLE(PR=expl) algorithm
(expl, functionTree) := List.map3Fold(expl, function differentiateExp(maxIter=maxIter-1), inDiffwrtCref, inInputData, inDiffType, inFunctionTree);
res := DAE.TUPLE(expl);
(res, _) := ExpressionSimplify.simplify1(res);
then (res, functionTree);

case DAE.MATRIX(ty=tp, integer=i, matrix=matrix) equation

(dmatrix, functionTree) = List.map3FoldList(matrix, function differentiateExp(maxIter=maxIter-1), inDiffwrtCref, inInputData, inDiffType, inFunctionTree);

res = DAE.MATRIX(tp, i, dmatrix);
(res,_) = ExpressionSimplify.simplify1(res);
case DAE.CAST(ty=tp, exp=e1) algorithm
(res, functionTree) := differentiateExp(e1, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
(res, _) := ExpressionSimplify.simplify1(res);
then (DAE.CAST(tp, res), functionTree);

case DAE.ASUB(exp=e1, sub=sub) algorithm
(res1, functionTree) := differentiateExp(e1, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
res := Expression.makeASUB(res1,sub);
(res, _) := ExpressionSimplify.simplify1(res);
then (res, functionTree);

// differentiate tsub
case DAE.TSUB(exp=e1, ix=i, ty=tp)
algorithm
(res1, functionTree) := differentiateExp(e1, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
Expand All @@ -719,8 +713,6 @@ algorithm
end if;
then (res, functionTree);


// differentiate rsub
case e1 as DAE.RSUB()
algorithm
// Try simplifying first.
Expand Down Expand Up @@ -748,41 +740,7 @@ algorithm
end if;
then (res, functionTree);

// differentiate tuple
case DAE.TUPLE(PR=expl) equation

(expl, functionTree) = List.map3Fold(expl, function differentiateExp(maxIter=maxIter-1), inDiffwrtCref, inInputData, inDiffType, inFunctionTree);

res = DAE.TUPLE(expl);
(res,_) = ExpressionSimplify.simplify1(res);

then (res, functionTree);

case DAE.IFEXP(expCond=e1, expThen=e2, expElse=e3) equation

(res1, functionTree) = differentiateExp(e2, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
(res2, functionTree) = differentiateExp(e3, inDiffwrtCref, inInputData, inDiffType, functionTree, maxIter-1);

res = DAE.IFEXP(e1, res1, res2);
(res,_) = ExpressionSimplify.simplify1(res);

then (res, functionTree);

// boolean expression, e.g. relation, are left as they are
case DAE.RELATION()
then (inExp, inFunctionTree);

case DAE.LBINARY()
then (inExp, inFunctionTree);

case DAE.LUNARY()
then (inExp, inFunctionTree);

case DAE.SIZE()
then (inExp, inFunctionTree);

case DAE.RANGE()
then (inExp, inFunctionTree);
case DAE.SIZE() then (inExp, inFunctionTree);

case DAE.REDUCTION()
algorithm
Expand Down

0 comments on commit 7968170

Please sign in to comment.