From c0ade4ee7be17b5ed3cb172b934a0049abef48f7 Mon Sep 17 00:00:00 2001 From: kabdelhak <38032125+kabdelhak@users.noreply.github.com> Date: Tue, 31 Oct 2023 10:42:59 +0100 Subject: [PATCH] Nb when initial (#11481) * [NB] backenddaeinfo spaces * [NB] update when equation removal for initialization --- .../Compiler/NBackEnd/Classes/NBEquation.mo | 26 +++ .../NBackEnd/Classes/NBStrongComponent.mo | 5 +- .../Compiler/NBackEnd/Classes/NBackendDAE.mo | 36 ++-- .../Modules/1_Main/NBInitialization.mo | 155 +++++++++++++++++- .../NBackEnd/Modules/1_Main/NBPartitioning.mo | 3 +- .../Compiler/NFFrontEnd/NFExpression.mo | 2 + 6 files changed, 204 insertions(+), 23 deletions(-) diff --git a/OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo b/OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo index b5389b0b866..614d6e92d52 100644 --- a/OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo +++ b/OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo @@ -42,6 +42,7 @@ public // New Frontend imports import Algorithm = NFAlgorithm; + import BackendDAE = NBackendDAE; import BackendExtension = NFBackendExtension; import Binding = NFBinding; import Call = NFCall; @@ -356,6 +357,11 @@ public output Integer size = product(i for i in 1 :: sizes(iter)); end size; + function dimensions + input Iterator iter; + output list dims = list(Dimension.fromInteger(s) for s in sizes(iter)); + end dimensions; + function createLocationReplacements "adds replacements rules for a single frame location" input Iterator iter "iterator to replace"; @@ -887,6 +893,18 @@ public Equation.createName(eq, idx, str); end makeAssignment; + function makeAlgorithm + input list stmts; + input Boolean init; + output Pointer eqn; + protected + Algorithm alg; + algorithm + alg := Algorithm.ALGORITHM(stmts, {}, {}, InstNode.EMPTY_NODE(), DAE.emptyElementSource); + alg := Algorithm.setInputsOutputs(alg); + eqn := BackendDAE.lowerAlgorithm(alg, init); + end makeAlgorithm; + function forEquationToString input Iterator iter "the iterator variable(s)"; input list body "iterated equations"; @@ -1504,6 +1522,7 @@ public case SCALAR_EQUATION() then eq.ty; case ARRAY_EQUATION() then eq.ty; case RECORD_EQUATION() then eq.ty; + case FOR_EQUATION() then Type.liftArrayRightList(getType(List.first(eq.body)), Iterator.dimensions(eq.iter)); else Type.REAL(); // TODO: WRONG there should not be an else case end match; end getType; @@ -1545,6 +1564,13 @@ public end match; end getForFrames; + function isDummy + input Equation eqn; + output Boolean b; + algorithm + b := match eqn case DUMMY_EQUATION() then true; else false; end match; + end isDummy; + function isDiscrete input Pointer eqn; output Boolean b; diff --git a/OMCompiler/Compiler/NBackEnd/Classes/NBStrongComponent.mo b/OMCompiler/Compiler/NBackEnd/Classes/NBStrongComponent.mo index a8044a0a62b..8c2472e2cb4 100644 --- a/OMCompiler/Compiler/NBackEnd/Classes/NBStrongComponent.mo +++ b/OMCompiler/Compiler/NBackEnd/Classes/NBStrongComponent.mo @@ -530,9 +530,12 @@ public case MULTI_COMPONENT() algorithm dependencies := Equation.collectCrefs(Pointer.access(comp.eqn), function Slice.getDependentCrefCausalized(set = set)); + dependencies := list(ComponentRef.stripIteratorSubscripts(dep) for dep in dependencies); dependencies := List.flatten(list(ComponentRef.scalarizeAll(dep) for dep in dependencies)); for var in comp.vars loop - updateDependencyMap(BVariable.getVarName(var), dependencies, map, jacType); + for cref in ComponentRef.scalarizeAll(BVariable.getVarName(var)) loop + updateDependencyMap(cref, dependencies, map, jacType); + end for; end for; then (); diff --git a/OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo b/OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo index fda0fc834af..a22eec47833 100644 --- a/OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo +++ b/OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo @@ -1065,7 +1065,7 @@ protected end match; end lowerWhenBranchStatement; - function lowerAlgorithm + public function lowerAlgorithm input Algorithm alg; input Boolean init; output Pointer eq; @@ -1087,7 +1087,7 @@ protected eq := Pointer.create(Equation.ALGORITHM(size, alg, alg.source, DAE.EXPAND(), attr)); end lowerAlgorithm; - function lowerEquationAttributes + protected function lowerEquationAttributes input Type ty; input Boolean init; output EquationAttributes attr; @@ -1307,21 +1307,21 @@ public Error.addCompilerNotification( "Partition statistics after passing the back-end:\n" - + "* Number of ODE partitions: ..................... " + p_ode + "\n" - + "* Number of algebraic partitions: ............... " + p_alg + "\n" - + "* Number of ODE event partitions: ............... " + p_ode_e + "\n" - + "* Number of algebraic event partitions: ......... " + p_alg_e + "\n" - + "* Number of clocked partitions: ................. " + p_clk + "\n" - + "* Number of initial partitions: ................. " + p_ini + "\n" - + "* Number of initial(lambda=0) partitions: ....... " + p_ini_0); + + " * Number of ODE partitions: ..................... " + p_ode + "\n" + + " * Number of algebraic partitions: ............... " + p_alg + "\n" + + " * Number of ODE event partitions: ............... " + p_ode_e + "\n" + + " * Number of algebraic event partitions: ......... " + p_alg_e + "\n" + + " * Number of clocked partitions: ................. " + p_clk + "\n" + + " * Number of initial partitions: ................. " + p_ini + "\n" + + " * Number of initial(lambda=0) partitions: ....... " + p_ini_0); Error.addCompilerNotification( "Variable statistics after passing the back-end:\n" - + "* Number of states: ............................. " + states + "\n" - + "* Number of discrete states: .................... " + discrete_states + "\n" - + "* Number of clocked states: ..................... " + clocked_states + "\n" - + "* Number of discrete variables: ................. " + discretes + "\n" - + "* Number of top-level inputs: ................... " + inputs); + + " * Number of states: ............................. " + states + "\n" + + " * Number of discrete states: .................... " + discrete_states + "\n" + + " * Number of clocked states: ..................... " + clocked_states + "\n" + + " * Number of discrete variables: ................. " + discretes + "\n" + + " * Number of top-level inputs: ................... " + inputs); // collect strong component info simulation strongcomponentinfo("Simulation", {bdae.ode, bdae.algebraic, bdae.ode_event, bdae.alg_event}); @@ -1357,10 +1357,10 @@ public Error.addCompilerNotification( "[" + phase + "] Strong Component statistics after passing the back-end:\n" - + "* Number of single strong components: ........... " + single_sc + "\n" - + "* Number of multi strong components: ............ " + multi_sc + "\n" - + "* Number of for-loop strong components: ......... " + for_sc + "\n" - + "* Number of algebraic-loop strong components: ... " + alg_sc); + + " * Number of single strong components: ........... " + single_sc + "\n" + + " * Number of multi strong components: ............ " + multi_sc + "\n" + + " * Number of for-loop strong components: ......... " + for_sc + "\n" + + " * Number of algebraic-loop strong components: ... " + alg_sc); end strongcomponentinfo; annotation(__OpenModelica_Interface="backend"); diff --git a/OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBInitialization.mo b/OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBInitialization.mo index 0e531f52e80..3ea48e85eea 100644 --- a/OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBInitialization.mo +++ b/OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBInitialization.mo @@ -37,6 +37,7 @@ encapsulated package NBInitialization protected // NF imports + import Algorithm = NFAlgorithm; import BackendExtension = NFBackendExtension; import Call = NFCall; import ComponentRef = NFComponentRef; @@ -46,6 +47,8 @@ protected import NFFunction.Function; import NFFlatten.{FunctionTree, FunctionTreeImpl}; import NFInstNode.InstNode; + import Operator = NFOperator; + import Statement = NFStatement; import Subscript = NFSubscript; import Type = NFType; import Variable = NFVariable; @@ -53,7 +56,7 @@ protected // Backend imports import BackendDAE = NBackendDAE; import BEquation = NBEquation; - import NBEquation.{Equation, EquationPointers, EqData, EquationAttributes, EquationKind, Iterator, WhenEquationBody}; + import NBEquation.{Equation, EquationPointers, EqData, EquationAttributes, EquationKind, Iterator, WhenEquationBody, WhenStatement, IfEquationBody}; import BVariable = NBVariable; import NBVariable.{VariablePointer, VariablePointers, VarData}; import Causalize = NBCausalize; @@ -294,7 +297,7 @@ public var_ptr := Slice.getT(state); // make unique iterators for the new for-loop name := BVariable.getVarName(var_ptr); - dims := Type.arrayDims(ComponentRef.nodeType(name)); + dims := Type.arrayDims(ComponentRef.getSubscriptedType(name)); (iterators, ranges, subscripts) := Flatten.makeIterators(name, dims); iter_crefs := list(ComponentRef.makeIterator(iter, Type.INTEGER()) for iter in iterators); iter_crefs := list(BackendDAE.lowerIteratorCref(iter) for iter in iter_crefs); @@ -377,7 +380,7 @@ public pre := BVariable.getPrePost(var_ptr); if Util.isSome(pre) then name := BVariable.getVarName(var_ptr); - dims := Type.arrayDims(ComponentRef.nodeType(name)); + dims := Type.arrayDims(ComponentRef.getSubscriptedType(name)); (iterators, ranges, subscripts) := Flatten.makeIterators(name, dims); frames := List.zip(list(ComponentRef.makeIterator(iter, Type.INTEGER()) for iter in iterators), ranges); @@ -502,5 +505,151 @@ public end match; end cleanupHomotopy; + function removeWhenEquation + "this function checks if an equation has to be removed before initialization. + true for: when branch without condition initial()" + input output Equation eqn; + algorithm + eqn := match eqn + local + Equation new_eqn; + list stmts; + Option if_body; + + // reduce the body of for equations + case Equation.FOR_EQUATION() algorithm + eqn.body := list(removeWhenEquation(b) for b in eqn.body); + then if List.all(eqn.body, Equation.isDummy) then Equation.DUMMY_EQUATION() else eqn; + + // reduce the body of when equations + case Equation.WHEN_EQUATION() algorithm + stmts := removeWhenEquationBody(SOME(eqn.body)); + if not listEmpty(stmts) then + new_eqn := Pointer.access(Equation.makeAlgorithm(stmts, true)); + new_eqn := Equation.setResidualVar(new_eqn, Equation.getResidualVar(Pointer.create(eqn))); + else + new_eqn := Equation.DUMMY_EQUATION(); + end if; + then new_eqn; + + // reduce the body of if equations + case Equation.IF_EQUATION() algorithm + eqn.body := removeWhenEquationIfBody(eqn.body); + eqn.size := IfEquationBody.size(eqn.body); + then if eqn.size > 0 then eqn else Equation.DUMMY_EQUATION(); + + // reduce the body of algorithms + case Equation.ALGORITHM() algorithm + stmts := removeWhenEquationAlgorithmBody(eqn.alg.statements); + if not listEmpty(stmts) then + new_eqn := Pointer.access(Equation.makeAlgorithm(stmts, true)); + new_eqn := Equation.setResidualVar(new_eqn, Equation.getResidualVar(Pointer.create(eqn))); + else + new_eqn := Equation.DUMMY_EQUATION(); + end if; + then new_eqn; + + else eqn; + end match; + end removeWhenEquation; + + function removeWhenEquationBody + input Option body_opt; + output list stmts; + algorithm + stmts := match body_opt + local + WhenEquationBody body; + + case SOME(body) algorithm + if isInitialCall(body.condition) then + // this is kept, return the statements + stmts := list(WhenStatement.toStatement(st) for st in body.when_stmts); + else + // dig deeper + stmts := removeWhenEquationBody(body.else_when); + end if; + then stmts; + + else {}; + end match; + end removeWhenEquationBody; + + function removeWhenEquationIfBody + input output IfEquationBody body; + algorithm + body.then_eqns := list(Pointer.apply(e, removeWhenEquation) for e in body.then_eqns); + if Util.isSome(body.else_if) then + body.else_if := SOME(removeWhenEquationIfBody(Util.getOption(body.else_if))); + end if; + end removeWhenEquationIfBody; + + function removeWhenEquationAlgorithmBody + input list in_stmts; + output list out_stmts; + protected + list> stmts = {}; + algorithm + for stmt in listReverse(in_stmts) loop + stmts := removeWhenEquationStatement(stmt) :: stmts; + end for; + out_stmts := List.flatten(stmts); + end removeWhenEquationAlgorithmBody; + + function removeWhenEquationStatement + input Statement stmt; + output list out_stmts = {}; + algorithm + out_stmts := match stmt + local + Expression cond; + list stmts; + list> stmts_acc = {}; + + case Statement.WHEN() algorithm + for tpl in stmt.branches loop + (cond, stmts) := tpl; + if isInitialCall(cond) then + out_stmts := stmts; + break; + end if; + end for; + then out_stmts; + + case Statement.FOR() algorithm + for body_stmt in listReverse(stmt.body) loop + stmts_acc := removeWhenEquationStatement(body_stmt) :: stmts_acc; + end for; + stmts := List.flatten(stmts_acc); + if not listEmpty(stmts) then + stmt.body := stmts; + out_stmts := {stmt}; + else + out_stmts := {}; + end if; + then out_stmts; + + else {stmt}; + end match; + end removeWhenEquationStatement; + + function isInitialCall + "checks if the expression is an initial call or can be simplified to be one. + ToDo: better apprach is to replace all initial calls with true and see if the expression can be simplified to true. + do this once ExpressionSimplify is mature enough" + input Expression condition; + output Boolean b; + algorithm + b := match condition + // it's an initial call -> true; + case Expression.CALL() then Call.isNamed(condition.call, "initial"); + // its an "or" expression, check if either argument is an initial call + case Expression.LBINARY(operator = Operator.OPERATOR(op = NFOperator.Op.OR)) + then isInitialCall(condition.exp1) or isInitialCall(condition.exp2); + // not an initial call. Ignore "and" constructs + else false; + end match; + end isInitialCall; + annotation(__OpenModelica_Interface="backend"); end NBInitialization; diff --git a/OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBPartitioning.mo b/OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBPartitioning.mo index b2ec6c70854..7bee5b7820c 100644 --- a/OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBPartitioning.mo +++ b/OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBPartitioning.mo @@ -91,7 +91,8 @@ public case (System.SystemType.INI, BackendDAE.MAIN(varData = BVariable.VAR_DATA_SIM(initials = variables), eqData = BEquation.EQ_DATA_SIM(initials = equations))) algorithm // ToDo: check if when equation is active during initialization - equations := EquationPointers.mapRemovePtr(equations, Equation.isWhenEquation); + equations := EquationPointers.map(equations, Initialization.removeWhenEquation); + equations := EquationPointers.compress(equations); bdae.init := list(sys for sys guard(not System.System.isEmpty(sys)) in partitioningNone(systemType, variables, equations)); then bdae; diff --git a/OMCompiler/Compiler/NFFrontEnd/NFExpression.mo b/OMCompiler/Compiler/NFFrontEnd/NFExpression.mo index 12016b3a5bb..01d489e0f58 100644 --- a/OMCompiler/Compiler/NFFrontEnd/NFExpression.mo +++ b/OMCompiler/Compiler/NFFrontEnd/NFExpression.mo @@ -2298,6 +2298,8 @@ public Type.toDAE(exp.ty), Type.toDAE(Type.FUNCTION(fn, NFType.FunctionType.FUNCTIONAL_VARIABLE))); + case MUTABLE() then toDAE(Mutable.access(exp.exp)); + else algorithm Error.assertion(false, getInstanceName() + " got unknown expression '" + toString(exp) + "'", sourceInfo());