Skip to content

Commit

Permalink
[NB] better handling of size 1 arrays (#10778)
Browse files Browse the repository at this point in the history
- inlining of size 1 for loops
 - removal of size 1 array subscripts just before solving
 - better dumps
  • Loading branch information
kabdelhak committed May 31, 2023
1 parent 00dd494 commit 4a16421
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 30 deletions.
8 changes: 6 additions & 2 deletions OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo
Expand Up @@ -63,6 +63,7 @@ public

// New Backend imports
import Evaluation = NBEvaluation;
import Inline = NBInline;
import Replacements = NBReplacements;
import StrongComponent = NBStrongComponent;
import Solve = NBSolve;
Expand Down Expand Up @@ -821,6 +822,7 @@ public
input EquationAttributes attr;
output Pointer<Equation> eq;
protected
Equation e;
Type ty = ComponentRef.getSubscriptedType(lhs, true);
algorithm
if listLength(frames) == 0 then
Expand All @@ -843,13 +845,15 @@ public
));
end if;
else
eq := Pointer.create(FOR_EQUATION(
e := FOR_EQUATION(
ty = ComponentRef.nodeType(lhs),
iter = Iterator.fromFrames(frames),
body = {SCALAR_EQUATION(ty, Expression.fromCref(lhs), rhs, DAE.emptyElementSource, attr)}, // this can also be an array?
source = DAE.emptyElementSource,
attr = attr
));
);
// inline if it has size 1
eq := Pointer.create(Inline.inlineForEquation(e));
end if;
Equation.createName(eq, idx, str);
end makeAssignment;
Expand Down
14 changes: 9 additions & 5 deletions OMCompiler/Compiler/NBackEnd/Classes/NBStrongComponent.mo
Expand Up @@ -656,6 +656,8 @@ public
list<Slice<VariablePointer>> comp_vars;
list<Slice<EquationPointer>> comp_eqns;
Tearing tearingSet;
Slice<VariablePointer> var_slice;
Slice<EquationPointer> eqn_slice;

// Size 1 strong component
// - case 1: sliced equation because of sliced variable
Expand All @@ -669,7 +671,13 @@ public
if size > 1 or Equation.isForEquation(eqn) then
// case 1: create the scalar variable and make sliced equation
cref := VariablePointers.varSlice(vars, var_scal_idx, mapping);
comp := SLICED_COMPONENT(cref, Slice.SLICE(var, {}), Slice.SLICE(eqn, {}), NBSolve.Status.UNPROCESSED);
try
({var_slice}, {eqn_slice}) := getLoopVarsAndEqns(comp_indices, eqn_to_var, mapping, vars, eqns);
else
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because single indices did not turn out to be single components."});
fail();
end try;
comp := SLICED_COMPONENT(cref, var_slice, eqn_slice, NBSolve.Status.UNPROCESSED);
else
// case 2: just create a single strong component
comp := match Pointer.access(eqn)
Expand All @@ -685,10 +693,6 @@ public
case _ algorithm
(comp_vars, comp_eqns) := getLoopVarsAndEqns(comp_indices, eqn_to_var, mapping, vars, eqns);
comp := match (comp_vars, comp_eqns)
local
Slice<VariablePointer> var_slice;
Slice<EquationPointer> eqn_slice;

case ({var_slice}, {eqn_slice}) guard(not Equation.isForEquation(Slice.getT(eqn_slice))) algorithm
if Slice.isFull(var_slice) then
comp := SINGLE_COMPONENT(
Expand Down
8 changes: 4 additions & 4 deletions OMCompiler/Compiler/NBackEnd/Classes/NBSystem.mo
Expand Up @@ -215,25 +215,25 @@ public
end if;
end getLoopResiduals;

function mapEquations
function mapEqn
input output System system;
input MapFunc func;
partial function MapFunc
input output BEquation.Equation e;
end MapFunc;
algorithm
system.equations := EquationPointers.map(system.equations, func);
end mapEquations;
end mapEqn;

function mapExpressions
function mapExp
input output System system;
input MapFunc func;
partial function MapFunc
input output Expression e;
end MapFunc;
algorithm
system.equations := EquationPointers.mapExp(system.equations, func);
end mapExpressions;
end mapExp;

function systemTypeString
input SystemType systemType;
Expand Down
6 changes: 5 additions & 1 deletion OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo
Expand Up @@ -756,7 +756,11 @@ protected
);

// merge iterators of each for equation instead of having nested loops (for {i in 1:10, j in 1:3, k in 1:5})
Pointer.update(body_elem_ptr, Equation.mergeIterators(body_elem));
body_elem := Equation.mergeIterators(body_elem);
// inline if size 1
body_elem := Inline.inlineForEquation(body_elem);

Pointer.update(body_elem_ptr, body_elem);
result := body_elem_ptr :: result;
end for;
else
Expand Down
24 changes: 12 additions & 12 deletions OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBInitialization.mo
Expand Up @@ -346,29 +346,29 @@ public
case BackendDAE.MAIN() algorithm

// initial() -> false
bdae.ode := list(System.mapEquations(sys, function cleanupInitialCall(init = false)) for sys in bdae.ode);
bdae.algebraic := list(System.mapEquations(sys, function cleanupInitialCall(init = false)) for sys in bdae.algebraic);
bdae.ode_event := list(System.mapEquations(sys, function cleanupInitialCall(init = false)) for sys in bdae.ode_event);
bdae.alg_event := list(System.mapEquations(sys, function cleanupInitialCall(init = false)) for sys in bdae.alg_event);
bdae.ode := list(System.mapEqn(sys, function cleanupInitialCall(init = false)) for sys in bdae.ode);
bdae.algebraic := list(System.mapEqn(sys, function cleanupInitialCall(init = false)) for sys in bdae.algebraic);
bdae.ode_event := list(System.mapEqn(sys, function cleanupInitialCall(init = false)) for sys in bdae.ode_event);
bdae.alg_event := list(System.mapEqn(sys, function cleanupInitialCall(init = false)) for sys in bdae.alg_event);
if Util.isSome(bdae.dae) then
bdae.dae := SOME(list(System.mapEquations(sys, function cleanupInitialCall(init = false)) for sys in Util.getOption(bdae.dae)));
bdae.dae := SOME(list(System.mapEqn(sys, function cleanupInitialCall(init = false)) for sys in Util.getOption(bdae.dae)));
end if;
// initial() -> true
bdae.init := list(System.mapEquations(sys, function cleanupInitialCall(init = true)) for sys in bdae.init);
bdae.init := list(System.mapEqn(sys, function cleanupInitialCall(init = true)) for sys in bdae.init);

// homotopy(actual, simplified) -> actual
bdae.ode := list(System.mapExpressions(sys, function cleanupHomotopy(init = false, hasHom = hasHom)) for sys in bdae.ode);
bdae.algebraic := list(System.mapExpressions(sys, function cleanupHomotopy(init = false, hasHom = hasHom)) for sys in bdae.algebraic);
bdae.ode_event := list(System.mapExpressions(sys, function cleanupHomotopy(init = false, hasHom = hasHom)) for sys in bdae.ode_event);
bdae.alg_event := list(System.mapExpressions(sys, function cleanupHomotopy(init = false, hasHom = hasHom)) for sys in bdae.alg_event);
bdae.ode := list(System.mapExp(sys, function cleanupHomotopy(init = false, hasHom = hasHom)) for sys in bdae.ode);
bdae.algebraic := list(System.mapExp(sys, function cleanupHomotopy(init = false, hasHom = hasHom)) for sys in bdae.algebraic);
bdae.ode_event := list(System.mapExp(sys, function cleanupHomotopy(init = false, hasHom = hasHom)) for sys in bdae.ode_event);
bdae.alg_event := list(System.mapExp(sys, function cleanupHomotopy(init = false, hasHom = hasHom)) for sys in bdae.alg_event);
if Util.isSome(bdae.dae) then
bdae.dae := SOME(list(System.mapExpressions(sys, function cleanupHomotopy(init = false, hasHom = hasHom)) for sys in Util.getOption(bdae.dae)));
bdae.dae := SOME(list(System.mapExp(sys, function cleanupHomotopy(init = false, hasHom = hasHom)) for sys in Util.getOption(bdae.dae)));
end if;

// create init_0 if homotopy call exists.
if Pointer.access(hasHom) then
bdae.init_0 := SOME(list(System.clone(sys, false) for sys in bdae.init));
bdae.init_0 := SOME(list(System.mapExpressions(sys, function cleanupHomotopy(init = true, hasHom = hasHom)) for sys in Util.getOption(bdae.init_0)));
bdae.init_0 := SOME(list(System.mapExp(sys, function cleanupHomotopy(init = true, hasHom = hasHom)) for sys in Util.getOption(bdae.init_0)));
end if;

then bdae;
Expand Down
32 changes: 31 additions & 1 deletion OMCompiler/Compiler/NBackEnd/Modules/2_Pre/NBInline.mo
Expand Up @@ -45,6 +45,7 @@ protected

// NF imports
import Call = NFCall;
import ComponentRef = NFComponentRef;
import Expression = NFExpression;
import NFFunction.Function;
import NFFlatten.FunctionTree;
Expand All @@ -54,7 +55,7 @@ protected
import Module = NBModule;
import BackendDAE = NBackendDAE;
import BEquation = NBEquation;
import NBEquation.{Equation, EquationPointers, EqData, EquationAttributes};
import NBEquation.{Equation, EquationPointers, EqData, EquationAttributes, Iterator};
import Replacements = NBReplacements;
import NBVariable.{VariablePointers, VarData};

Expand Down Expand Up @@ -84,6 +85,34 @@ public
// =========================================================================
// TYPES, UNIONTYPES AND MEMBER FUNCTIONS
// =========================================================================
function inlineForEquation
input output Equation eqn;
algorithm
eqn := match eqn
local
Equation new_eqn;
UnorderedMap<ComponentRef, Expression> replacements "replacement map for iterator crefs";
list<ComponentRef> names;
list<Expression> ranges;
ComponentRef name;
Expression range;
Integer start;

case Equation.FOR_EQUATION(body = {new_eqn}) guard(Equation.size(Pointer.create(eqn)) == 1) algorithm
replacements := UnorderedMap.new<Expression>(ComponentRef.hash, ComponentRef.isEqual);
(names, ranges) := Iterator.getFrames(eqn.iter);
for tpl in List.zip(names, ranges) loop
(name, range) := tpl;
(start, _, _) := Expression.getIntegerRange(range);
UnorderedMap.add(name, Expression.INTEGER(start), replacements);
end for;
new_eqn := Equation.map(new_eqn, function Replacements.applySimpleExp(replacements = replacements));
then new_eqn;

else eqn;
end match;
end inlineForEquation;

protected
function inline extends Module.inlineInterface;
protected
Expand Down Expand Up @@ -252,5 +281,6 @@ protected
end match;
end inlineRecordConstructorElements;


annotation(__OpenModelica_Interface="backend");
end NBInline;
3 changes: 3 additions & 0 deletions OMCompiler/Compiler/NBackEnd/Modules/3_Post/NBSolve.mo
Expand Up @@ -51,6 +51,7 @@ public

// backend imports
import BackendDAE = NBackendDAE;
import BackendUtil = NBBackendUtil;
import Causalize = NBCausalize;
import Differentiate = NBDifferentiate;
import NBEquation.{Equation, EquationPointer, EquationPointers, EqData, IfEquationBody, SlicingStatus};
Expand Down Expand Up @@ -133,6 +134,8 @@ public
ComponentRef name;
list<Pointer<Equation>> sliced_eqns;
algorithm
// remove size one array subscripts for solving
system := System.mapExp(system, BackendUtil.removeSizeOneArraySubscriptsExp);
if Util.isSome(system.strongComponents) then
for comp in Util.getOption(system.strongComponents) loop
if UnorderedMap.contains(comp, duplicate_map) then
Expand Down
26 changes: 26 additions & 0 deletions OMCompiler/Compiler/NBackEnd/Util/NBBackendUtil.mo
Expand Up @@ -257,5 +257,31 @@ public
end if;
end isContinuousFold;

function removeSizeOneArraySubscriptsExp
input output Expression exp;
algorithm
exp := match exp
case Expression.CREF() algorithm
exp.cref := removeSizeOneArraySubscriptsCref(exp.cref);
then exp;
else exp;
end match;
end removeSizeOneArraySubscriptsExp;

function removeSizeOneArraySubscriptsCref
input output ComponentRef cref;
algorithm
cref := match cref
case ComponentRef.CREF() algorithm
if Type.isArray(cref.ty) and Type.sizeOf(cref.ty) == 1 then
cref.subscripts := {};
end if;
cref.restCref := removeSizeOneArraySubscriptsCref(cref.restCref);
then cref;

else cref;
end match;
end removeSizeOneArraySubscriptsCref;

annotation(__OpenModelica_Interface="backend");
end NBBackendUtil;
12 changes: 7 additions & 5 deletions OMCompiler/Compiler/NSimCode/NSimStrongComponent.mo
Expand Up @@ -254,7 +254,7 @@ public
case ALGORITHM() then str + "(" + intString(blck.index) + ") Algorithm\n" + Statement.toStringList(blck.stmts, str) + "\n";
case INVERSE_ALGORITHM() then str + "(" + intString(blck.index) + ") Inverse Algorithm\n" + Statement.toStringList(blck.stmts, str) + "\n";
case IF() then str + "(" + intString(blck.index) + ") " + List.toString(blck.branches, function ifTplStr(str = str), "", str, str + "else ", str + "end if;\n");
case WHEN() then str + "(" + intString(blck.index) + ") " + whenString(blck.conditions, blck.when_stmts, blck.else_when);
case WHEN() then str + "(" + intString(blck.index) + ") " + whenString(blck.conditions, blck.when_stmts, blck.else_when, str);
case LINEAR() then str + "(" + intString(blck.system.index) + ") " + LinearSystem.toString(blck.system, str);
case NONLINEAR() then str + "(" + intString(blck.system.index) + ") " + NonlinearSystem.toString(blck.system, str);
case HYBRID() then str + "(" + intString(blck.index) + ") Hybrid\n"; // ToDo!
Expand Down Expand Up @@ -1168,14 +1168,16 @@ public
input list<ComponentRef> conditions;
input list<WhenStatement> when_stmts;
input Option<Block> else_when;
output String str = "";
input output String str = "";
protected
String indent = str;
algorithm
str := "when " + List.toString(conditions, ComponentRef.toString) + "\n" +
List.toString(when_stmts, function WhenStatement.toString(str = "\t"), "", "", "\n") + "\n";
List.toString(when_stmts, function WhenStatement.toString(str = indent + "\t"), "", "", "\n", "") + "\n";
if Util.isSome(else_when) then
str := str + "else" + toString(Util.getOption(else_when));
str := str + indent + "else" + toString(Util.getOption(else_when));
else
str := str + "end when;";
str := str + indent + "end when;\n";
end if;
end whenString;

Expand Down
14 changes: 14 additions & 0 deletions OMCompiler/Compiler/SimCode/SimCodeUtil.mo
Expand Up @@ -9136,6 +9136,20 @@ algorithm
s := s+"\n";
then s;

case SimCode.SES_GENERIC_ASSIGN()
algorithm
s := intString(eqSysIn.index) +": "+ " (SES_GENERIC_ASSIGN) " + " call index: " + intString(eqSysIn.call_index) + "\n";
s := s + "\tindices: " + List.toString(eqSysIn.scal_indices, intString, "", "{", ", ", "}", true, 10) + "\n";
then s;

case SimCode.SES_ENTWINED_ASSIGN()
algorithm
s := intString(eqSysIn.index) +": "+ " (SES_ENTWINED_ASSIGN)\n";
s := s + "\tcall order: " + List.toString(eqSysIn.call_order, intString, "", "{", ", ", "}", true, 10) + "\n";
s := s + List.toString(eqSysIn.single_calls, simEqSystemString, "", "\t", "\n", "");
s := s + "\n";
then s;

else
then
"SOMETHING DIFFERENT\n";
Expand Down

0 comments on commit 4a16421

Please sign in to comment.