Skip to content

Commit

Permalink
Nb when initial (#11481)
Browse files Browse the repository at this point in the history
* [NB] backenddaeinfo spaces

* [NB] update when equation removal for initialization
  • Loading branch information
kabdelhak committed Oct 31, 2023
1 parent 4d4dbc7 commit c0ade4e
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 23 deletions.
26 changes: 26 additions & 0 deletions OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo
Expand Up @@ -42,6 +42,7 @@ public

// New Frontend imports
import Algorithm = NFAlgorithm;
import BackendDAE = NBackendDAE;
import BackendExtension = NFBackendExtension;
import Binding = NFBinding;
import Call = NFCall;
Expand Down Expand Up @@ -356,6 +357,11 @@ public
output Integer size = product(i for i in 1 :: sizes(iter));
end size;

function dimensions
input Iterator iter;
output list<Dimension> dims = list(Dimension.fromInteger(s) for s in sizes(iter));
end dimensions;

function createLocationReplacements
"adds replacements rules for a single frame location"
input Iterator iter "iterator to replace";
Expand Down Expand Up @@ -887,6 +893,18 @@ public
Equation.createName(eq, idx, str);
end makeAssignment;

function makeAlgorithm
input list<Statement> stmts;
input Boolean init;
output Pointer<Equation> eqn;
protected
Algorithm alg;
algorithm
alg := Algorithm.ALGORITHM(stmts, {}, {}, InstNode.EMPTY_NODE(), DAE.emptyElementSource);
alg := Algorithm.setInputsOutputs(alg);
eqn := BackendDAE.lowerAlgorithm(alg, init);
end makeAlgorithm;

function forEquationToString
input Iterator iter "the iterator variable(s)";
input list<Equation> body "iterated equations";
Expand Down Expand Up @@ -1504,6 +1522,7 @@ public
case SCALAR_EQUATION() then eq.ty;
case ARRAY_EQUATION() then eq.ty;
case RECORD_EQUATION() then eq.ty;
case FOR_EQUATION() then Type.liftArrayRightList(getType(List.first(eq.body)), Iterator.dimensions(eq.iter));
else Type.REAL(); // TODO: WRONG there should not be an else case
end match;
end getType;
Expand Down Expand Up @@ -1545,6 +1564,13 @@ public
end match;
end getForFrames;

function isDummy
input Equation eqn;
output Boolean b;
algorithm
b := match eqn case DUMMY_EQUATION() then true; else false; end match;
end isDummy;

function isDiscrete
input Pointer<Equation> eqn;
output Boolean b;
Expand Down
5 changes: 4 additions & 1 deletion OMCompiler/Compiler/NBackEnd/Classes/NBStrongComponent.mo
Expand Up @@ -530,9 +530,12 @@ public

case MULTI_COMPONENT() algorithm
dependencies := Equation.collectCrefs(Pointer.access(comp.eqn), function Slice.getDependentCrefCausalized(set = set));
dependencies := list(ComponentRef.stripIteratorSubscripts(dep) for dep in dependencies);
dependencies := List.flatten(list(ComponentRef.scalarizeAll(dep) for dep in dependencies));
for var in comp.vars loop
updateDependencyMap(BVariable.getVarName(var), dependencies, map, jacType);
for cref in ComponentRef.scalarizeAll(BVariable.getVarName(var)) loop
updateDependencyMap(cref, dependencies, map, jacType);
end for;
end for;
then ();

Expand Down
36 changes: 18 additions & 18 deletions OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo
Expand Up @@ -1065,7 +1065,7 @@ protected
end match;
end lowerWhenBranchStatement;
function lowerAlgorithm
public function lowerAlgorithm
input Algorithm alg;
input Boolean init;
output Pointer<Equation> eq;
Expand All @@ -1087,7 +1087,7 @@ protected
eq := Pointer.create(Equation.ALGORITHM(size, alg, alg.source, DAE.EXPAND(), attr));
end lowerAlgorithm;
function lowerEquationAttributes
protected function lowerEquationAttributes
input Type ty;
input Boolean init;
output EquationAttributes attr;
Expand Down Expand Up @@ -1307,21 +1307,21 @@ public
Error.addCompilerNotification(
"Partition statistics after passing the back-end:\n"
+ "* Number of ODE partitions: ..................... " + p_ode + "\n"
+ "* Number of algebraic partitions: ............... " + p_alg + "\n"
+ "* Number of ODE event partitions: ............... " + p_ode_e + "\n"
+ "* Number of algebraic event partitions: ......... " + p_alg_e + "\n"
+ "* Number of clocked partitions: ................. " + p_clk + "\n"
+ "* Number of initial partitions: ................. " + p_ini + "\n"
+ "* Number of initial(lambda=0) partitions: ....... " + p_ini_0);
+ " * Number of ODE partitions: ..................... " + p_ode + "\n"
+ " * Number of algebraic partitions: ............... " + p_alg + "\n"
+ " * Number of ODE event partitions: ............... " + p_ode_e + "\n"
+ " * Number of algebraic event partitions: ......... " + p_alg_e + "\n"
+ " * Number of clocked partitions: ................. " + p_clk + "\n"
+ " * Number of initial partitions: ................. " + p_ini + "\n"
+ " * Number of initial(lambda=0) partitions: ....... " + p_ini_0);
Error.addCompilerNotification(
"Variable statistics after passing the back-end:\n"
+ "* Number of states: ............................. " + states + "\n"
+ "* Number of discrete states: .................... " + discrete_states + "\n"
+ "* Number of clocked states: ..................... " + clocked_states + "\n"
+ "* Number of discrete variables: ................. " + discretes + "\n"
+ "* Number of top-level inputs: ................... " + inputs);
+ " * Number of states: ............................. " + states + "\n"
+ " * Number of discrete states: .................... " + discrete_states + "\n"
+ " * Number of clocked states: ..................... " + clocked_states + "\n"
+ " * Number of discrete variables: ................. " + discretes + "\n"
+ " * Number of top-level inputs: ................... " + inputs);
// collect strong component info simulation
strongcomponentinfo("Simulation", {bdae.ode, bdae.algebraic, bdae.ode_event, bdae.alg_event});
Expand Down Expand Up @@ -1357,10 +1357,10 @@ public
Error.addCompilerNotification(
"[" + phase + "] Strong Component statistics after passing the back-end:\n"
+ "* Number of single strong components: ........... " + single_sc + "\n"
+ "* Number of multi strong components: ............ " + multi_sc + "\n"
+ "* Number of for-loop strong components: ......... " + for_sc + "\n"
+ "* Number of algebraic-loop strong components: ... " + alg_sc);
+ " * Number of single strong components: ........... " + single_sc + "\n"
+ " * Number of multi strong components: ............ " + multi_sc + "\n"
+ " * Number of for-loop strong components: ......... " + for_sc + "\n"
+ " * Number of algebraic-loop strong components: ... " + alg_sc);
end strongcomponentinfo;
annotation(__OpenModelica_Interface="backend");
Expand Down
155 changes: 152 additions & 3 deletions OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBInitialization.mo
Expand Up @@ -37,6 +37,7 @@ encapsulated package NBInitialization

protected
// NF imports
import Algorithm = NFAlgorithm;
import BackendExtension = NFBackendExtension;
import Call = NFCall;
import ComponentRef = NFComponentRef;
Expand All @@ -46,14 +47,16 @@ protected
import NFFunction.Function;
import NFFlatten.{FunctionTree, FunctionTreeImpl};
import NFInstNode.InstNode;
import Operator = NFOperator;
import Statement = NFStatement;
import Subscript = NFSubscript;
import Type = NFType;
import Variable = NFVariable;

// Backend imports
import BackendDAE = NBackendDAE;
import BEquation = NBEquation;
import NBEquation.{Equation, EquationPointers, EqData, EquationAttributes, EquationKind, Iterator, WhenEquationBody};
import NBEquation.{Equation, EquationPointers, EqData, EquationAttributes, EquationKind, Iterator, WhenEquationBody, WhenStatement, IfEquationBody};
import BVariable = NBVariable;
import NBVariable.{VariablePointer, VariablePointers, VarData};
import Causalize = NBCausalize;
Expand Down Expand Up @@ -294,7 +297,7 @@ public
var_ptr := Slice.getT(state);
// make unique iterators for the new for-loop
name := BVariable.getVarName(var_ptr);
dims := Type.arrayDims(ComponentRef.nodeType(name));
dims := Type.arrayDims(ComponentRef.getSubscriptedType(name));
(iterators, ranges, subscripts) := Flatten.makeIterators(name, dims);
iter_crefs := list(ComponentRef.makeIterator(iter, Type.INTEGER()) for iter in iterators);
iter_crefs := list(BackendDAE.lowerIteratorCref(iter) for iter in iter_crefs);
Expand Down Expand Up @@ -377,7 +380,7 @@ public
pre := BVariable.getPrePost(var_ptr);
if Util.isSome(pre) then
name := BVariable.getVarName(var_ptr);
dims := Type.arrayDims(ComponentRef.nodeType(name));
dims := Type.arrayDims(ComponentRef.getSubscriptedType(name));
(iterators, ranges, subscripts) := Flatten.makeIterators(name, dims);
frames := List.zip(list(ComponentRef.makeIterator(iter, Type.INTEGER()) for iter in iterators), ranges);

Expand Down Expand Up @@ -502,5 +505,151 @@ public
end match;
end cleanupHomotopy;

function removeWhenEquation
"this function checks if an equation has to be removed before initialization.
true for: when branch without condition initial()"
input output Equation eqn;
algorithm
eqn := match eqn
local
Equation new_eqn;
list<Statement> stmts;
Option<IfEquationBody> if_body;

// reduce the body of for equations
case Equation.FOR_EQUATION() algorithm
eqn.body := list(removeWhenEquation(b) for b in eqn.body);
then if List.all(eqn.body, Equation.isDummy) then Equation.DUMMY_EQUATION() else eqn;

// reduce the body of when equations
case Equation.WHEN_EQUATION() algorithm
stmts := removeWhenEquationBody(SOME(eqn.body));
if not listEmpty(stmts) then
new_eqn := Pointer.access(Equation.makeAlgorithm(stmts, true));
new_eqn := Equation.setResidualVar(new_eqn, Equation.getResidualVar(Pointer.create(eqn)));
else
new_eqn := Equation.DUMMY_EQUATION();
end if;
then new_eqn;

// reduce the body of if equations
case Equation.IF_EQUATION() algorithm
eqn.body := removeWhenEquationIfBody(eqn.body);
eqn.size := IfEquationBody.size(eqn.body);
then if eqn.size > 0 then eqn else Equation.DUMMY_EQUATION();

// reduce the body of algorithms
case Equation.ALGORITHM() algorithm
stmts := removeWhenEquationAlgorithmBody(eqn.alg.statements);
if not listEmpty(stmts) then
new_eqn := Pointer.access(Equation.makeAlgorithm(stmts, true));
new_eqn := Equation.setResidualVar(new_eqn, Equation.getResidualVar(Pointer.create(eqn)));
else
new_eqn := Equation.DUMMY_EQUATION();
end if;
then new_eqn;

else eqn;
end match;
end removeWhenEquation;

function removeWhenEquationBody
input Option<WhenEquationBody> body_opt;
output list<Statement> stmts;
algorithm
stmts := match body_opt
local
WhenEquationBody body;

case SOME(body) algorithm
if isInitialCall(body.condition) then
// this is kept, return the statements
stmts := list(WhenStatement.toStatement(st) for st in body.when_stmts);
else
// dig deeper
stmts := removeWhenEquationBody(body.else_when);
end if;
then stmts;

else {};
end match;
end removeWhenEquationBody;

function removeWhenEquationIfBody
input output IfEquationBody body;
algorithm
body.then_eqns := list(Pointer.apply(e, removeWhenEquation) for e in body.then_eqns);
if Util.isSome(body.else_if) then
body.else_if := SOME(removeWhenEquationIfBody(Util.getOption(body.else_if)));
end if;
end removeWhenEquationIfBody;

function removeWhenEquationAlgorithmBody
input list<Statement> in_stmts;
output list<Statement> out_stmts;
protected
list<list<Statement>> stmts = {};
algorithm
for stmt in listReverse(in_stmts) loop
stmts := removeWhenEquationStatement(stmt) :: stmts;
end for;
out_stmts := List.flatten(stmts);
end removeWhenEquationAlgorithmBody;

function removeWhenEquationStatement
input Statement stmt;
output list<Statement> out_stmts = {};
algorithm
out_stmts := match stmt
local
Expression cond;
list<Statement> stmts;
list<list<Statement>> stmts_acc = {};

case Statement.WHEN() algorithm
for tpl in stmt.branches loop
(cond, stmts) := tpl;
if isInitialCall(cond) then
out_stmts := stmts;
break;
end if;
end for;
then out_stmts;

case Statement.FOR() algorithm
for body_stmt in listReverse(stmt.body) loop
stmts_acc := removeWhenEquationStatement(body_stmt) :: stmts_acc;
end for;
stmts := List.flatten(stmts_acc);
if not listEmpty(stmts) then
stmt.body := stmts;
out_stmts := {stmt};
else
out_stmts := {};
end if;
then out_stmts;

else {stmt};
end match;
end removeWhenEquationStatement;

function isInitialCall
"checks if the expression is an initial call or can be simplified to be one.
ToDo: better apprach is to replace all initial calls with true and see if the expression can be simplified to true.
do this once ExpressionSimplify is mature enough"
input Expression condition;
output Boolean b;
algorithm
b := match condition
// it's an initial call -> true;
case Expression.CALL() then Call.isNamed(condition.call, "initial");
// its an "or" expression, check if either argument is an initial call
case Expression.LBINARY(operator = Operator.OPERATOR(op = NFOperator.Op.OR))
then isInitialCall(condition.exp1) or isInitialCall(condition.exp2);
// not an initial call. Ignore "and" constructs
else false;
end match;
end isInitialCall;

annotation(__OpenModelica_Interface="backend");
end NBInitialization;
Expand Up @@ -91,7 +91,8 @@ public
case (System.SystemType.INI, BackendDAE.MAIN(varData = BVariable.VAR_DATA_SIM(initials = variables), eqData = BEquation.EQ_DATA_SIM(initials = equations)))
algorithm
// ToDo: check if when equation is active during initialization
equations := EquationPointers.mapRemovePtr(equations, Equation.isWhenEquation);
equations := EquationPointers.map(equations, Initialization.removeWhenEquation);
equations := EquationPointers.compress(equations);
bdae.init := list(sys for sys guard(not System.System.isEmpty(sys)) in partitioningNone(systemType, variables, equations));
then bdae;

Expand Down
2 changes: 2 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFExpression.mo
Expand Up @@ -2298,6 +2298,8 @@ public
Type.toDAE(exp.ty),
Type.toDAE(Type.FUNCTION(fn, NFType.FunctionType.FUNCTIONAL_VARIABLE)));

case MUTABLE() then toDAE(Mutable.access(exp.exp));

else
algorithm
Error.assertion(false, getInstanceName() + " got unknown expression '" + toString(exp) + "'", sourceInfo());
Expand Down

0 comments on commit c0ade4e

Please sign in to comment.