Skip to content

Commit

Permalink
[NB] implement Equation.isEqual (#11864)
Browse files Browse the repository at this point in the history
- used to properly track StrongComponent Alias
 - before we only used the names which lead to wrong strong component alias in the case of homotopy optimization
  • Loading branch information
kabdelhak committed Jan 23, 2024
1 parent 9e816f5 commit 00dc993
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 5 deletions.
90 changes: 89 additions & 1 deletion OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo
Expand Up @@ -798,6 +798,53 @@ public
output Boolean b = ComponentRef.isEqual(getEqnName(eqn1), getEqnName(eqn2));
end equalName;

function isEqualPtrTpl
input tuple<EquationPointer, EquationPointer> tpl;
output Boolean b;
protected
EquationPointer eqn1, eqn2;
algorithm
(eqn1, eqn2) := tpl;
b := isEqualPtr(eqn1, eqn2);
end isEqualPtrTpl;

function isEqualPtr
input Pointer<Equation> eqn1;
input Pointer<Equation> eqn2;
output Boolean b = isEqual(Pointer.access(eqn1), Pointer.access(eqn2));
end isEqualPtr;

function isEqualTpl
input tuple<Equation, Equation> tpl;
output Boolean b;
protected
Equation eqn1, eqn2;
algorithm
(eqn1, eqn2) := tpl;
b := isEqual(eqn1, eqn2);
end isEqualTpl;

function isEqual
input Equation eqn1;
input Equation eqn2;
output Boolean b;
algorithm
b := match (eqn1, eqn2)
case (SCALAR_EQUATION(), SCALAR_EQUATION()) then Expression.isEqual(eqn1.lhs, eqn2.lhs) and Expression.isEqual(eqn1.rhs, eqn2.rhs);
case (ARRAY_EQUATION(), ARRAY_EQUATION()) then Expression.isEqual(eqn1.lhs, eqn2.lhs) and Expression.isEqual(eqn1.rhs, eqn2.rhs);
case (RECORD_EQUATION(), RECORD_EQUATION()) then Expression.isEqual(eqn1.lhs, eqn2.lhs) and Expression.isEqual(eqn1.rhs, eqn2.rhs);
// ToDo: This is wrong! implement the Algorithm.isEqual!
// case (ALGORITHM(), ALGORITHM()) then Algorithm.isEqual(eqn1.alg, eqn2.alg);
case (ALGORITHM(), ALGORITHM()) then equalName(Pointer.create(eqn1), Pointer.create(eqn2));
case (IF_EQUATION(), IF_EQUATION()) then IfEquationBody.isEqual(eqn1.body, eqn2.body);
case (FOR_EQUATION(), FOR_EQUATION()) then Iterator.isEqual(eqn1.iter, eqn2.iter) and List.all(List.zip(eqn1.body, eqn2.body), isEqualTpl);
case (WHEN_EQUATION(), WHEN_EQUATION()) then WhenEquationBody.isEqual(eqn1.body, eqn2.body);
case (AUX_EQUATION(), AUX_EQUATION()) then BVariable.equalName(eqn1.auxiliary, eqn2.auxiliary) and Util.optionEqual(eqn1.body, eqn2.body, isEqual);
case (DUMMY_EQUATION(), DUMMY_EQUATION()) then true;
else false;
end match;
end isEqual;

function getEqnName
input Pointer<Equation> eqn;
output ComponentRef name;
Expand Down Expand Up @@ -2163,7 +2210,6 @@ public
ifBody.else_if := SOME(else_if);
end if;
end if;

end map;

function size
Expand All @@ -2172,6 +2218,14 @@ public
output Integer size = sum(Equation.size(eqn) for eqn in body.then_eqns);
end size;

function isEqual
input IfEquationBody body1;
input IfEquationBody body2;
output Boolean b;
algorithm
b := List.all(List.zip(body1.then_eqns, body2.then_eqns), Equation.isEqualPtrTpl) and Util.optionEqual(body1.else_if, body2.else_if, isEqual);
end isEqual;

function createNames
input IfEquationBody body;
input Pointer<Integer> idx;
Expand Down Expand Up @@ -2307,6 +2361,14 @@ public
output Integer s = sum(WhenStatement.size(stmt) for stmt in body.when_stmts);
end size;

function isEqual
input WhenEquationBody body1;
input WhenEquationBody body2;
output Boolean b;
algorithm
b := List.all(List.zip(body1.when_stmts, body2.when_stmts), WhenStatement.isEqualTpl) and Util.optionEqual(body1.else_when, body2.else_when, isEqual);
end isEqual;

function getBodyAttributes
"gets all conditions crefs as a list (has to be applied AFTER Event module)"
input WhenEquationBody body;
Expand Down Expand Up @@ -2632,6 +2694,32 @@ public
end match;
end toString;

function isEqualTpl
input tuple<WhenStatement, WhenStatement> tpl;
output Boolean b;
protected
WhenStatement stmt1;
WhenStatement stmt2;
algorithm
(stmt1, stmt2) := tpl;
b := isEqual(stmt1, stmt2);
end isEqualTpl;

function isEqual
input WhenStatement stmt1;
input WhenStatement stmt2;
output Boolean b;
algorithm
b := match (stmt1, stmt2)
case (ASSIGN(), ASSIGN()) then Expression.isEqual(stmt1.lhs, stmt2.lhs) and Expression.isEqual(stmt1.rhs, stmt2.rhs);
case (REINIT(), REINIT()) then ComponentRef.isEqual(stmt1.stateVar, stmt2.stateVar) and Expression.isEqual(stmt1.value, stmt2.value);
case (ASSERT(), ASSERT()) then Expression.isEqual(stmt1.condition, stmt2.condition) and Expression.isEqual(stmt1.message, stmt2.message) and Expression.isEqual(stmt1.level, stmt2.level);
case (TERMINATE(), TERMINATE()) then Expression.isEqual(stmt1.message, stmt2.message);
case (NORETCALL(), NORETCALL()) then Expression.isEqual(stmt1.exp, stmt2.exp);
else false;
end match;
end isEqual;

function toStatement
input WhenStatement wstmt;
output Statement stmt;
Expand Down
8 changes: 4 additions & 4 deletions OMCompiler/Compiler/NBackEnd/Classes/NBStrongComponent.mo
Expand Up @@ -297,10 +297,10 @@ public
output Boolean b;
algorithm
b := match(comp1, comp2)
case (SINGLE_COMPONENT(), SINGLE_COMPONENT()) then BVariable.equalName(comp1.var, comp2.var) and Equation.equalName(comp1.eqn, comp2.eqn);
case (MULTI_COMPONENT(), MULTI_COMPONENT()) then Equation.equalName(comp1.eqn, comp2.eqn);
case (SLICED_COMPONENT(), SLICED_COMPONENT()) then ComponentRef.isEqual(comp1.var_cref, comp2.var_cref) and Slice.isEqual(comp1.eqn, comp2.eqn, Equation.equalName);
case (GENERIC_COMPONENT(), GENERIC_COMPONENT()) then Slice.isEqual(comp1.eqn, comp2.eqn, Equation.equalName);
case (SINGLE_COMPONENT(), SINGLE_COMPONENT()) then BVariable.equalName(comp1.var, comp2.var) and Equation.isEqualPtr(comp1.eqn, comp2.eqn);
case (MULTI_COMPONENT(), MULTI_COMPONENT()) then Equation.isEqualPtr(comp1.eqn, comp2.eqn);
case (SLICED_COMPONENT(), SLICED_COMPONENT()) then ComponentRef.isEqual(comp1.var_cref, comp2.var_cref) and Slice.isEqual(comp1.eqn, comp2.eqn, Equation.isEqualPtr);
case (GENERIC_COMPONENT(), GENERIC_COMPONENT()) then Slice.isEqual(comp1.eqn, comp2.eqn, Equation.isEqualPtr);
case (ENTWINED_COMPONENT(), ENTWINED_COMPONENT()) then List.isEqualOnTrue(comp1.entwined_slices, comp2.entwined_slices, isEqual);
case (ALGEBRAIC_LOOP(), ALGEBRAIC_LOOP()) then Tearing.isEqual(comp1.strict, comp2.strict);
case (ALIAS(), ALIAS()) then AliasInfo.isEqual(comp1.aliasInfo, comp2.aliasInfo);
Expand Down

0 comments on commit 00dc993

Please sign in to comment.