Skip to content

Commit dd9ae73

Browse files
authored
[NB] update array handling (#8869)
- update scalarizing - deal with model subscripts (just as any other subscript) - deal with scalar subscripts (include 1 as subscript for scalars instead of empty) - update entwining of equations - allow intersection to be empty - implement several dumping functions
1 parent 75428a3 commit dd9ae73

File tree

12 files changed

+420
-234
lines changed

12 files changed

+420
-234
lines changed

OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,8 @@ public
274274
rest2 := intersectRest(iter2.name, start2, step2, stop2, start_max-step2, stop_min+step2);
275275
then (intersection, rest1, rest2);
276276

277-
else algorithm
278-
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because only single iterators with equal step can be intersected:\n"
279-
+ Iterator.toString(iter1) + "\n" + Iterator.toString(iter2) + "\n"});
280-
then fail();
277+
// cannot intersect
278+
else (EMPTY(), (iter1, EMPTY()), (EMPTY(), iter2));
281279
end match;
282280
end intersect;
283281

@@ -1536,20 +1534,26 @@ public
15361534

15371535
function entwine
15381536
input list<Equation> eqn_lst "has to be for-loops with combinable ranges";
1537+
input Integer nesting_level = 0;
15391538
output list<Equation> entwined = {} "returns a single for-loop on top level if it is possible";
15401539
protected
15411540
Equation eqn1, eqn2, next;
15421541
list<Equation> rest, tmp;
15431542
Iterator intersection, rest1_left, rest1_right, rest2_left, rest2_right;
1543+
String shift = StringUtil.repeat(" ", nesting_level);
15441544
algorithm
1545+
if Flags.isSet(Flags.DUMP_SLICE) then
1546+
print(shift + "[" + intString(nesting_level) + "] ### Entwining following equations:\n"
1547+
+ List.toString(eqn_lst, function Equation.toString(str = shift + " "), "", "", "\n", "\n\n"));
1548+
end if;
15451549
eqn1 :: rest := eqn_lst;
15461550
while not listEmpty(rest) loop
15471551
eqn2 :: rest := rest;
15481552
eqn1 := match (eqn1, eqn2)
15491553

15501554
// entwine body if possible - equal iterator -> no intersecting
15511555
case (FOR_EQUATION(), FOR_EQUATION()) guard(Iterator.isEqual(eqn1.iter, eqn2.iter)) algorithm
1552-
eqn1.body := entwine(listAppend(eqn1.body, eqn2.body));
1556+
eqn1.body := entwine(listAppend(eqn1.body, eqn2.body), nesting_level + 1);
15531557
then eqn1;
15541558

15551559
// if the iterators are not equal, they have to be intersected and the respective rests have to be handled
@@ -1566,7 +1570,7 @@ public
15661570
tmp := FOR_EQUATION(
15671571
ty = eqn1.ty,
15681572
iter = intersection,
1569-
body = entwine(listAppend(eqn1.body, eqn2.body)),
1573+
body = entwine(listAppend(eqn1.body, eqn2.body), nesting_level + 1),
15701574
source = eqn1.source,
15711575
attr = eqn1.attr
15721576
) :: tmp;
@@ -1589,6 +1593,10 @@ public
15891593
end match;
15901594
end while;
15911595
entwined := listReverse(eqn1 :: entwined);
1596+
if Flags.isSet(Flags.DUMP_SLICE) then
1597+
print(shift + "[" + intString(nesting_level) + "] +++ Result of entwining:\n"
1598+
+ List.toString(entwined, function Equation.toString(str = shift + " "), "", "", "\n", "\n\n"));
1599+
end if;
15921600
end entwine;
15931601

15941602
function slice

OMCompiler/Compiler/NBackEnd/Classes/NBVariable.mo

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,9 +1475,16 @@ public
14751475
"Returns true if the variable is in the variable pointer array."
14761476
input Pointer<Variable> var;
14771477
input VariablePointers variables;
1478-
output Boolean b = getVarIndex(variables, getVarName(var)) > 0;
1478+
output Boolean b = containsCref(getVarName(var), variables);
14791479
end contains;
14801480

1481+
function containsCref
1482+
"Returns true if a variable with this name is in the variable pointer array."
1483+
input ComponentRef cref;
1484+
input VariablePointers variables;
1485+
output Boolean b = getVarIndex(variables, cref) > 0;
1486+
end containsCref;
1487+
14811488
function getVarNames
14821489
"returns a list of crefs representing the names of all variables"
14831490
input VariablePointers variables;
@@ -1578,7 +1585,7 @@ public
15781585
// flatten potential arrays
15791586
if Type.isArray(var.ty) then
15801587
flattened := true;
1581-
scalar_vars := Scalarize.scalarizeVariable(var);
1588+
scalar_vars := Scalarize.scalarizeBackendVariable(var);
15821589
else
15831590
scalar_vars := {Pointer.access(var_ptr)};
15841591
end if;

OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ protected
364364
VariablePointers variables, unknowns, knowns, initials, auxiliaries, aliasVars, nonTrivialAlias;
365365
VariablePointers states, derivatives, algebraics, discretes, previous;
366366
VariablePointers parameters, constants;
367+
Pointer<list<Pointer<Variable>>> binding_iter_lst = Pointer.create({});
367368
Boolean scalarized = Flags.isSet(Flags.NF_SCALARIZE);
368369
algorithm
369370
// instantiate variable data (with one more space for time variable);
@@ -458,8 +459,11 @@ protected
458459
parameters := VariablePointers.fromList(parameters_lst, scalarized);
459460
constants := VariablePointers.fromList(constants_lst, scalarized);
460461

461-
/* lower the variable bindings */
462-
VariablePointers.map(variables, function lowerVariableBinding(variables = variables));
462+
/* lower the variable bindings and add binding iterators */
463+
variables := VariablePointers.map(variables, function collectVariableBindingIterators(variables = variables, binding_iter_lst = binding_iter_lst));
464+
variables := VariablePointers.addList(Pointer.access(binding_iter_lst), variables);
465+
knowns := VariablePointers.addList(Pointer.access(binding_iter_lst), knowns);
466+
variables := VariablePointers.map(variables, function lowerVariableBinding(variables = variables));
463467

464468
/* create variable data */
465469
variableData := BVariable.VAR_DATA_SIM(variables, unknowns, knowns, initials, auxiliaries, aliasVars, nonTrivialAlias,
@@ -549,17 +553,32 @@ protected
549553
end match;
550554
end lowerVariableKind;
551555

556+
function collectVariableBindingIterators
557+
input output Variable var;
558+
input VariablePointers variables;
559+
input Pointer<list<Pointer<Variable>>> binding_iter_lst;
560+
algorithm
561+
_ := match var
562+
local
563+
Binding binding;
564+
case Variable.VARIABLE(binding = binding as Binding.TYPED_BINDING()) algorithm
565+
// collect all iterators (only locally known) so that they have a respective variable
566+
Expression.map(binding.bindingExp, function collectBindingIterators(variables = variables, binding_iter_lst = binding_iter_lst));
567+
then ();
568+
else ();
569+
end match;
570+
end collectVariableBindingIterators;
571+
552572
function lowerVariableBinding
553573
input output Variable var;
554574
input VariablePointers variables;
555575
algorithm
556576
var := match var
557577
local
558578
Binding binding;
559-
case Variable.VARIABLE(binding = binding as Binding.TYPED_BINDING())
560-
algorithm
561-
binding.bindingExp := Expression.map(binding.bindingExp, function lowerComponentReferenceExp(variables = variables));
562-
var.binding := binding;
579+
case Variable.VARIABLE(binding = binding as Binding.TYPED_BINDING()) algorithm
580+
binding.bindingExp := Expression.map(binding.bindingExp, function lowerComponentReferenceExp(variables = variables));
581+
var.binding := binding;
563582
then var;
564583
else var;
565584
end match;
@@ -1106,6 +1125,23 @@ protected
11061125
end try;
11071126
end lowerComponentReference;
11081127

1128+
function collectBindingIterators
1129+
"collects all iterators in bindings and creates variables for them.
1130+
in bindings they are only known locally but they still need a respective variable"
1131+
input output Expression exp;
1132+
input VariablePointers variables;
1133+
input Pointer<list<Pointer<Variable>>> binding_iter_lst;
1134+
algorithm
1135+
_ := match exp
1136+
local
1137+
ComponentRef cref;
1138+
case Expression.CREF(cref = cref) guard(not VariablePointers.containsCref(cref, variables)) algorithm
1139+
Pointer.update(binding_iter_lst, lowerIterator(cref) :: Pointer.access(binding_iter_lst));
1140+
then ();
1141+
else ();
1142+
end match;
1143+
end collectBindingIterators;
1144+
11091145
public
11101146
function lowerComponentReferenceInstNode
11111147
"Adds the pointer to a variable to a component reference. This function needs
@@ -1127,7 +1163,5 @@ public
11271163
end match;
11281164
end lowerComponentReferenceInstNode;
11291165

1130-
protected
1131-
11321166
annotation(__OpenModelica_Interface="backend");
11331167
end NBackendDAE;

OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBSorting.mo

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ public
6060
Integer mode "solve mode index";
6161
end PSEUDO_BUCKET_KEY;
6262

63+
function toString
64+
input PseudoBucketKey key;
65+
output String str = "key: (start[s]: " + intString(key.eqn_start_idx)
66+
+ ", index[a]:" + intString(key.eqn_arr_idx) + ", mode: " + intString(key.mode) + ")";
67+
end toString;
68+
6369
function hash
6470
input PseudoBucketKey key "the key to hash";
6571
input Integer modulo "modulo value";
@@ -90,6 +96,26 @@ public
9096
array<list<Integer>> entwined_arr;
9197
end PSEUDO_BUCKET_ENTWINED;
9298

99+
function toString
100+
input PseudoBucketValue val;
101+
output String str;
102+
algorithm
103+
str := match val
104+
local
105+
PseudoBucketKey k;
106+
PseudoBucketValue v;
107+
case PSEUDO_BUCKET_SINGLE() then "\n\tval: (" + ComponentRef.toString(val.cref_to_solve)
108+
+ ", [" + intString(val.first_comp) + ":" + intString(val.last_comp) +"]" + ")";
109+
case PSEUDO_BUCKET_ENTWINED() algorithm
110+
str := "\n\tentwined:";
111+
for tpl in val.entwined_lst loop
112+
(k, v) := tpl;
113+
str := str + "\n\t" + PseudoBucketKey.toString(k) + " :: " + PseudoBucketValue.toString(v);
114+
end for;
115+
then str;
116+
end match;
117+
end toString;
118+
93119
function addIndices
94120
input output PseudoBucketValue val;
95121
input Integer eqn_scal_idx;
@@ -113,6 +139,11 @@ public
113139
array<Boolean> marks;
114140
end PSEUDO_BUCKET;
115141

142+
function toString
143+
input PseudoBucket bucket;
144+
output String str = UnorderedMap.toString(bucket.bucket, PseudoBucketKey.toString, PseudoBucketValue.toString);
145+
end toString;
146+
116147
function create
117148
"recollects subsets of multi-dimensional equations that have to be solved in the same way.
118149
currently only for loops!"
@@ -317,6 +348,11 @@ public
317348

318349
// recollect array information
319350
bucket := PseudoBucket.create(comps_indices, matching.eqn_to_var, adj.mapping, adj.modes);
351+
352+
if Flags.isSet(Flags.DUMP_SLICE) then
353+
print("--- BUCKETS:\n" + PseudoBucket.toString(bucket) + "\n\n");
354+
end if;
355+
320356
for idx_lst in comps_indices loop
321357
comp_opt := StrongComponent.createPseudo(idx_lst, matching.eqn_to_var, vars, eqns, adj.mapping, adj.modes, bucket);
322358
if Util.isSome(comp_opt) then

OMCompiler/Compiler/NBackEnd/Modules/2_Pre/NBRemoveSimpleEquations.mo

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,11 @@ protected
453453
guard(BVariable.isParamOrConst(BVariable.getVarPointer(exp.cref)) or ComponentRef.isTime(exp.cref))
454454
then tpl;
455455

456+
// fail for multidimensional crefs for now
457+
case Expression.CREF()
458+
guard(BVariable.size(BVariable.getVarPointer(exp.cref)) > 1)
459+
then FAILED_CREF_TPL;
460+
456461
// variable found
457462
// 1. not time and not param or const
458463
// 2. less than two previous variables

OMCompiler/Compiler/NBackEnd/Util/NBSlice.mo

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ public
373373
var_arr_idx := UnorderedMap.getSafe(stripped, map);
374374
(var_start, _) := mapping.var_AtS[var_arr_idx];
375375
sizes := ComponentRef.sizes(stripped);
376-
subs := list(Subscript.toExp(sub) for sub in ComponentRef.subscriptsAllWithWholeFlat(cref));
376+
subs := ComponentRef.subscriptsToExpression(cref);
377377
scal_lst := combineFrames2Indices(var_start, sizes, subs, frames, UnorderedMap.new<Expression>(ComponentRef.hash, ComponentRef.isEqual));
378378

379379
if listLength(scal_lst) <> eqn_size then
@@ -413,6 +413,7 @@ public
413413
list<ComponentRef> names;
414414
list<Expression> ranges, subs;
415415
list<list<Expression>> new_subs, new_row_cref_subs;
416+
list<Subscript> evaluated_subs;
416417
list<tuple<ComponentRef, Expression>> frames;
417418
list<ComponentRef> new_row_crefs = {}, new_dep_crefs;
418419
list<list<ComponentRef>> scalar_dependenciesT = {};
@@ -422,19 +423,22 @@ public
422423
frames := List.zip(names, ranges);
423424

424425
// get new subscripts for row cref
425-
subs := list(Subscript.toExp(sub) for sub in ComponentRef.subscriptsAllWithWholeFlat(row_cref));
426+
subs := ComponentRef.subscriptsToExpression(row_cref);
427+
//subs := list(Subscript.toExp(sub) for sub in ComponentRef.subscriptsAllWithWholeFlat(row_cref));
426428
new_row_cref_subs := combineFrames2Exp(subs, frames, UnorderedMap.new<Expression>(ComponentRef.hash, ComponentRef.isEqual));
427429

428430
// reapply new subscripts for each frame location
429431
stripped := ComponentRef.stripSubscriptsAll(row_cref);
430432
for new_subs_single in new_row_cref_subs loop
431-
new_row_crefs := ComponentRef.mergeSubscripts(list(Subscript.fromTypedExp(exp) for exp in new_subs_single), stripped) :: new_row_crefs;
433+
evaluated_subs := list(Subscript.fromTypedExp(exp) for exp in new_subs_single);
434+
new_row_crefs := ComponentRef.mergeSubscripts(evaluated_subs, stripped, false, true) :: new_row_crefs;
432435
end for;
433436

434437
// get the scalar crefs for each cref
435438
if not listEmpty(dependencies) then
436439
for cref in dependencies loop
437-
subs := list(Subscript.toExp(sub) for sub in ComponentRef.subscriptsAllWithWholeFlat(cref));
440+
subs := ComponentRef.subscriptsToExpression(cref);
441+
//subs := list(Subscript.toExp(sub) for sub in ComponentRef.subscriptsAllWithWholeFlat(cref));
438442
new_subs := combineFrames2Exp(subs, frames, UnorderedMap.new<Expression>(ComponentRef.hash, ComponentRef.isEqual));
439443

440444
if listLength(new_subs) <> listLength(new_row_crefs) then
@@ -448,7 +452,8 @@ public
448452
new_dep_crefs := {};
449453
stripped := ComponentRef.stripSubscriptsAll(cref);
450454
for new_subs_single in new_subs loop
451-
new_dep_crefs := ComponentRef.mergeSubscripts(list(Subscript.fromTypedExp(exp) for exp in new_subs_single), stripped) :: new_dep_crefs;
455+
evaluated_subs := list(Subscript.fromTypedExp(exp) for exp in new_subs_single);
456+
new_dep_crefs := ComponentRef.mergeSubscripts(evaluated_subs, stripped, false, true) :: new_dep_crefs;
452457
end for;
453458
scalar_dependenciesT := new_dep_crefs :: scalar_dependenciesT;
454459
end for;
@@ -848,6 +853,7 @@ protected
848853
ComponentRef iterator;
849854
Expression range;
850855
Integer start, step, stop;
856+
list<Expression> local_subs;
851857

852858
// only occurs for non-for-loop equations (no frames to replace)
853859
case {} then {subs};
@@ -860,7 +866,8 @@ protected
860866
UnorderedMap.add(iterator, Expression.INTEGER(index), replacements);
861867
if listEmpty(rest) then
862868
// bottom line, resolve current configuration and create index for it
863-
new_subs := list(SimplifyExp.simplify(Expression.map(sub, function Replacements.applySimpleExp(replacements = replacements))) for sub in subs) :: new_subs;
869+
local_subs := list(SimplifyExp.simplify(Expression.map(sub, function Replacements.applySimpleExp(replacements = replacements))) for sub in subs);
870+
new_subs := listReverse(local_subs) :: new_subs;
864871
else
865872
// not last frame, go deeper
866873
new_subs := combineFrames2Exp(subs, rest, replacements, new_subs);

0 commit comments

Comments
 (0)