Skip to content

Commit 2cbdc11

Browse files
authored
Attempt to fix generation of CALL expressions on lhs of assignments and solved equations. (#592)
- Do not expand records when differentiating functions. - Try simplifying a RSUB expression before differentiating. - Print error when finding a CALL expression on the lhs.
1 parent a83b51b commit 2cbdc11

File tree

2 files changed

+51
-26
lines changed

2 files changed

+51
-26
lines changed

OMCompiler/Compiler/BackEnd/Differentiate.mo

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -704,21 +704,29 @@ algorithm
704704
// differentiate rsub
705705
case e1 as DAE.RSUB()
706706
algorithm
707-
(res1, functionTree) := differentiateExp(e1.exp, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
708-
if not referenceEq(e1.exp, res1) then
709-
try
710-
(expl, strLst) := match res1
711-
case DAE.RECORD(exps=expl,comp=strLst) then (expl, strLst);
712-
case DAE.CALL(path=p1,expLst=expl,attr=DAE.CALL_ATTR(ty=DAE.T_COMPLEX(complexClassType=ClassInf.RECORD(path=p2), varLst=varLst)))
713-
guard AbsynUtil.pathEqual(p1,p2)
714-
then (expl, list(v.name for v in varLst));
715-
end match;
716-
res := listGet(expl, List.position1OnTrue(strLst, stringEq, e1.fieldName));
717-
else
718-
e1.exp := res1;
719-
(res,_) := ExpressionSimplify.simplify1(e1);
720-
end try;
721-
end if;
707+
// Try simplifying first.
708+
(res, b) := ExpressionSimplify.simplify(e1);
709+
if b then
710+
(res, functionTree) := differentiateExp(res, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
711+
else
712+
(res1, functionTree) := differentiateExp(e1.exp, inDiffwrtCref, inInputData, inDiffType, inFunctionTree, maxIter-1);
713+
// This might not be needed anymore. If it is simplifiable
714+
// Then it would have been simplified above.
715+
if not referenceEq(e1.exp, res1) then
716+
try
717+
(expl, strLst) := match res1
718+
case DAE.RECORD(exps=expl,comp=strLst) then (expl, strLst);
719+
case DAE.CALL(path=p1,expLst=expl,attr=DAE.CALL_ATTR(ty=DAE.T_COMPLEX(complexClassType=ClassInf.RECORD(path=p2), varLst=varLst)))
720+
guard AbsynUtil.pathEqual(p1,p2)
721+
then (expl, list(v.name for v in varLst));
722+
end match;
723+
res := listGet(expl, List.position1OnTrue(strLst, stringEq, e1.fieldName));
724+
else
725+
e1.exp := res1;
726+
(res,_) := ExpressionSimplify.simplify1(e1);
727+
end try;
728+
end if;
729+
end if;
722730
then (res, functionTree);
723731

724732
// differentiate tuple
@@ -1073,6 +1081,16 @@ algorithm
10731081
// This part contains general rules for differentation crefs
10741082
//
10751083

1084+
// case for records without expanding the record
1085+
case ((DAE.CREF(componentRef = cr,ty = tp as DAE.T_COMPLEX(varLst=varLst,complexClassType=ClassInf.RECORD(path)))), _, BackendDAE.DIFFINPUTDATA(matrixName=SOME(matrixName)), BackendDAE.DIFFERENTIATION_FUNCTION(), _)
1086+
equation
1087+
cr = ComponentReference.prependStringCref(BackendDAE.functionDerivativeNamePrefix, cr);
1088+
cr = ComponentReference.prependStringCref(matrixName, cr);
1089+
1090+
res = Expression.makeCrefExp(cr, tp);
1091+
then
1092+
(res, inFunctionTree);
1093+
10761094
// case for Records
10771095
case ((DAE.CREF(componentRef = cr,ty = tp as DAE.T_COMPLEX(varLst=varLst,complexClassType=ClassInf.RECORD(path)))), _, _, _, _)
10781096
equation

OMCompiler/Compiler/Template/CodegenCFunctions.tpl

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2847,14 +2847,22 @@ template algStmtAssignRecord(DAE.Statement stmt, Context context, Text &preExp,
28472847
// The right hand side might be a call so we create a tmp var here and assign it. If the rhs is not
28482848
// a call this is an uncessary copy. however, we can live with it since it is not a deep copy and the
28492849
// c compiler should just be able to optimzie it away.
2850-
let tmp_rec = tempDecl(rec_typename,&varDecls)
2851-
let rhs = daeExp(rhs_exp, context, &preExp, &varDecls, &auxFunction)
2852-
let vars = args |> arg => ( ", &(" + daeExp(arg, context, &preExp, &varDecls, &auxFunction) + ")" )
2853-
<<
2854-
<%tmp_rec%> = <%rhs%>;
2855-
<%rec_typename%>_copy_to_vars(<%tmp_rec%><%vars%>);
2856-
>>
2857-
// <%error(sourceInfo(), 'Left hand side of an assignment is a call expression. <%ExpressionDumpTpl.dumpExp(exp1,"\"")%> = <%ExpressionDumpTpl.dumpExp(exp,"\"")%>')%>
2850+
// let tmp_rec = tempDecl(rec_typename,&varDecls)
2851+
// let rhs = daeExp(rhs_exp, context, &preExp, &varDecls, &auxFunction)
2852+
// let vars = args |> arg => ( ", &(" + daeExp(arg, context, &preExp, &varDecls, &auxFunction) + ")" )
2853+
// <<
2854+
// <%tmp_rec%> = <%rhs%>;
2855+
// <%rec_typename%>_copy_to_vars(<%tmp_rec%><%vars%>);
2856+
// >>
2857+
// let rhs_exp_str = daeExp(rhs_exp, context, &preExp, &varDecls, &auxFunction)
2858+
// let tmp_rec = tempDecl(rec_typename, &varDecls)
2859+
// let assigns = splitRhsForRecordAssignmentToMemberAssignments(args, ty, tmp_rec)
2860+
// |> stmt => algStatement(stmt, context, &varDecls, &auxFunction)
2861+
// <<
2862+
// <%tmp_rec%> = <%rhs_exp_str%>;
2863+
// <%assigns%>
2864+
// >>
2865+
error(sourceInfo(), 'Left hand side of an assignment is a call expression. <%ExpressionDumpTpl.dumpExp(exp1,"\"")%> = <%ExpressionDumpTpl.dumpExp(exp,"\"")%>')
28582866
case STMT_ASSIGN(exp1=RECORD(), type_ = ty as T_COMPLEX(complexClassType=RECORD(__))) then
28592867
error(sourceInfo(), 'Left hand side of an assignment is a record expression. <%ExpressionDumpTpl.dumpExp(exp1,"\"")%>')
28602868
case STMT_ASSIGN() then
@@ -2878,18 +2886,17 @@ template assignRhsExpToRecordCrefSimContext(ComponentRef lhs_cref, Exp rhs_exp,
28782886
::=
28792887
let lhs = contextCref(lhs_cref, context, &auxFunction)
28802888
let rec_typename = expTypeShort(rec_type)
2881-
let &rhs_exp_str = buffer ""
28822889

28832890
match rhs_exp
28842891
case CREF(componentRef = cr) then
2885-
let &rhs_exp_str += contextCref(cr, context, auxFunction)
2892+
let rhs_exp_str = contextCref(cr, context, auxFunction)
28862893
let assigns = splitRecordAssignmentToMemberAssignments(lhs_cref, rec_type, rhs_exp_str)
28872894
|> stmt => algStatement(stmt, context, &varDecls, &auxFunction)
28882895
<<
28892896
<%assigns%>
28902897
>>
28912898
else
2892-
let &rhs_exp_str += daeExp(rhs_exp, context, &preExp, &varDecls, &auxFunction)
2899+
let rhs_exp_str = daeExp(rhs_exp, context, &preExp, &varDecls, &auxFunction)
28932900
let tmp_rec = tempDecl(rec_typename, &varDecls)
28942901
let assigns = splitRecordAssignmentToMemberAssignments(lhs_cref, rec_type, tmp_rec)
28952902
|> stmt => algStatement(stmt, context, &varDecls, &auxFunction)

0 commit comments

Comments
 (0)