Skip to content

Commit 4974278

Browse files
authored
[NB] adjacency: remove skips in conditions (#12271)
- remove skips in full adjacency matrix for variables in conditions - create utility function for conditions - update debugging information for adjacency creation
1 parent de46d5b commit 4974278

File tree

2 files changed

+163
-116
lines changed

2 files changed

+163
-116
lines changed

OMCompiler/Compiler/NBackEnd/Util/NBAdjacency.mo

Lines changed: 65 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ public
501501
for name in UnorderedMap.keyList(eqns_map) loop
502502
index := UnorderedMap.getSafe(name, eqns_map, sourceInfo());
503503
filtered := Solvability.filter(UnorderedSet.toList(occ[index]), sol[index], vars_map, min, max);
504-
// now run the normal createPseudo pipeline but use dependencies
504+
// upgrade the row and all meta data
505505
upgradeRow(EquationPointers.getEqnAt(eqns, index), index, filtered, dep[index], rep[index], vars_map, adj.m, adj.mapping, adj.modes);
506506
end for;
507507
adj.mT := transposeScalar(adj.m, arrayLength(adj.mapping.var_StA));
@@ -1085,18 +1085,23 @@ public
10851085
Iterator iter = Equation.getForIterator(eqn);
10861086
Type ty = Equation.getType(eqn, true);
10871087
algorithm
1088-
// don't do this for if equations as soon as we properly split them
1089-
if Equation.isAlgorithm(eqn_ptr) or Equation.isIfEquation(eqn_ptr) then
1090-
// algorithm full dependency
1091-
(eqn_scal_idx, eqn_size) := mapping.eqn_AtS[eqn_arr_idx];
1092-
row := Slice.upgradeRowFull(dependencies, map, mapping);
1093-
for i in 0:eqn_size-1 loop
1094-
updateIntegerRow(m, eqn_scal_idx+i, row);
1095-
end for;
1088+
try
1089+
// don't do this for if equations as soon as we properly split them
1090+
if Equation.isAlgorithm(eqn_ptr) or Equation.isIfEquation(eqn_ptr) then
1091+
// algorithm full dependency
1092+
(eqn_scal_idx, eqn_size) := mapping.eqn_AtS[eqn_arr_idx];
1093+
row := Slice.upgradeRowFull(dependencies, map, mapping);
1094+
for i in 0:eqn_size-1 loop
1095+
updateIntegerRow(m, eqn_scal_idx+i, row);
1096+
end for;
1097+
else
1098+
// todo: if, when single equation (needs to be updated for if)
1099+
Slice.upgradeRow(Equation.getEqnName(eqn_ptr), eqn_arr_idx, iter, ty, dependencies, dep, rep, map, m, mapping, modes);
1100+
end if;
10961101
else
1097-
// todo: if, when single equation (needs to be updated for if)
1098-
Slice.upgradeRow(Equation.getEqnName(eqn_ptr), eqn_arr_idx, iter, ty, dependencies, dep, rep, map, m, mapping, modes);
1099-
end if;
1102+
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for:\n" + Equation.pointerToString(eqn_ptr)});
1103+
fail();
1104+
end try;
11001105
end upgradeRow;
11011106

11021107
function updateIntegerRow
@@ -1212,6 +1217,24 @@ public
12121217
end if;
12131218
end skip;
12141219

1220+
function removeSkips
1221+
input ComponentRef cref;
1222+
input UnorderedMap<ComponentRef, Dependency> map;
1223+
protected
1224+
Option<Dependency> opt_dep = UnorderedMap.get(cref, map);
1225+
Dependency dep;
1226+
algorithm
1227+
if Util.isSome(opt_dep) then
1228+
SOME(dep) := opt_dep;
1229+
dep.skips := {};
1230+
UnorderedMap.add(cref, dep, map);
1231+
else
1232+
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because cref "
1233+
+ ComponentRef.toString(cref) + " was not found in the map."});
1234+
fail();
1235+
end if;
1236+
end removeSkips;
1237+
12151238
function updateList
12161239
input list<ComponentRef> lst;
12171240
input Integer num;
@@ -1233,6 +1256,15 @@ public
12331256
end for;
12341257
end skipList;
12351258

1259+
function removeSkipsList
1260+
input list<ComponentRef> lst;
1261+
input UnorderedMap<ComponentRef, Dependency> map;
1262+
algorithm
1263+
for cref in lst loop
1264+
removeSkips(cref, map);
1265+
end for;
1266+
end removeSkipsList;
1267+
12361268
function addListFull
12371269
"adds a list and applies full dependency"
12381270
input list<ComponentRef> lst;
@@ -1625,16 +1657,16 @@ public
16251657
Solvability.updateList(UnorderedSet.toList(set), Solvability.UNSOLVABLE(), sol_map);
16261658
then set;
16271659

1628-
// variables in conditions and not occuring in both branches are unsolvable
1660+
// variables in conditions are unsolvable and variables not occuring in both branches are implicit
16291661
case Expression.IF() algorithm
16301662
set1 := collectDependencies(exp.trueBranch, map, dep_map, sol_map, rep_set);
16311663
set2 := collectDependencies(exp.falseBranch, map, dep_map, sol_map, rep_set);
1632-
// variables not occuring in both branches will be tagged unsolvable
1664+
// variables not occuring in both branches will be tagged implicit
16331665
diff := UnorderedSet.sym_difference(set1, set2);
1634-
Solvability.updateList(UnorderedSet.toList(diff), Solvability.UNSOLVABLE(), sol_map);
1635-
// variables in conditions are unsolvable
1666+
Solvability.updateList(UnorderedSet.toList(diff), Solvability.IMPLICIT(), sol_map);
1667+
// variables in conditions are unsolvable and their skips have to be removed
16361668
set := collectDependencies(exp.condition, map, dep_map, sol_map, rep_set);
1637-
Solvability.updateList(UnorderedSet.toList(set), Solvability.UNSOLVABLE(), sol_map);
1669+
updateConditionCrefs(UnorderedSet.toList(set), dep_map, sol_map);
16381670
then UnorderedSet.union_list({set, set1, set2}, ComponentRef.hash, ComponentRef.isEqual);
16391671

16401672
// for array constructors replace all iterators (temporarily)
@@ -1753,10 +1785,9 @@ public
17531785
list<UnorderedSet<ComponentRef>> sets1 = {};
17541786
UnorderedSet<ComponentRef> set1, set2, diff;
17551787
algorithm
1756-
// variables in conditions are unsolvable and reduced
1788+
// variables in conditions are unsolvable, reduced and get their skips removed
17571789
set := collectDependencies(body.condition, map, dep_map, sol_map, rep_set);
1758-
Dependency.updateList(UnorderedSet.toList(set), -1, false, dep_map);
1759-
Solvability.updateList(UnorderedSet.toList(set), Solvability.UNSOLVABLE(), sol_map);
1790+
updateConditionCrefs(UnorderedSet.toList(set), dep_map, sol_map);
17601791

17611792
// get variables from 'then' branch
17621793
for eqn in body.then_eqns loop
@@ -1793,10 +1824,9 @@ public
17931824
list<UnorderedSet<ComponentRef>> lst = {}, lst1, lst2;
17941825
list<tuple<UnorderedSet<ComponentRef>, UnorderedSet<ComponentRef>>> tpl_lst = {};
17951826
algorithm
1796-
// variables in conditions are unsolvable and reduced
1827+
// variables in conditions are unsolvable, reduced and get their skips removed
17971828
set := collectDependencies(body.condition, map, dep_map, sol_map, rep_set);
1798-
Dependency.updateList(UnorderedSet.toList(set), -1, false, dep_map);
1799-
Solvability.updateList(UnorderedSet.toList(set), Solvability.UNSOLVABLE(), sol_map);
1829+
updateConditionCrefs(UnorderedSet.toList(set), dep_map, sol_map);
18001830

18011831
// make condition repeat if the body is larger than 1
18021832
if sum(WhenStatement.size(stmt) for stmt in body.when_stmts) > 1 then
@@ -1853,6 +1883,7 @@ public
18531883
case WhenStatement.ASSERT() algorithm
18541884
set1 := UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);
18551885
set2 := collectDependencies(stmt.condition, map, dep_map, sol_map, rep_set);
1886+
updateConditionCrefs(UnorderedSet.toList(set2), dep_map, sol_map);
18561887
then (set1, set2);
18571888

18581889
else algorithm
@@ -1862,5 +1893,16 @@ public
18621893
end match;
18631894
end collectDependenciesStmt;
18641895

1896+
function updateConditionCrefs
1897+
"variables in conditions are unsolvable, reduced and get their skips removed"
1898+
input list<ComponentRef> crefs;
1899+
input UnorderedMap<ComponentRef, Dependency> dep_map;
1900+
input UnorderedMap<ComponentRef, Solvability> sol_map;
1901+
algorithm
1902+
Dependency.removeSkipsList(crefs, dep_map);
1903+
Dependency.updateList(crefs, -1, false, dep_map);
1904+
Solvability.updateList(crefs, Solvability.UNSOLVABLE(), sol_map);
1905+
end updateConditionCrefs;
1906+
18651907
annotation(__OpenModelica_Interface="backend");
18661908
end NBAdjacency;

OMCompiler/Compiler/NBackEnd/Util/NBSlice.mo

Lines changed: 98 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,108 +1197,113 @@ protected
11971197
Boolean repeated;
11981198
Mode mode;
11991199
algorithm
1200-
// I. resolve the skips
1201-
d := UnorderedMap.getSafe(cref, dep, sourceInfo());
1202-
(start, _) := mapping.eqn_AtS[eqn_arr_idx];
1203-
(skip_idx, skip_ty) := resolveSkips(start, ty, d.skips);
1204-
1205-
// get equation and iterator sizes and frames
1206-
body_size := Type.sizeOf(skip_ty);
1207-
iter_size := Iterator.size(iter);
1208-
size := body_size * iter_size;
1209-
(names, ranges) := Iterator.getFrames(iter);
1210-
frames := List.zip(names, ranges);
1211-
1212-
// II. check for regular vs. reduced dimensions
1213-
regulars := Dependency.toBoolean(d);
1214-
if List.all(regulars, Util.id) then
1215-
// II.1 all regular - single dependency per row.
1216-
mode := Mode.create(eqn_name, {cref}, false);
1217-
scalarized := listReverse(ComponentRef.scalarizeAll(cref));
1218-
map3 := UnorderedMap.new<Val2>(ComponentRef.hash, ComponentRef.isEqual);
1219-
for scal in scalarized loop
1220-
UnorderedMap.add(scal, getCrefInFrameIndices(scal, frames, mapping, map), map3);
1221-
end for;
1222-
scal_size := listLength(List.flatten(UnorderedMap.valueList(map3)));
1223-
// either the scalarized list has to be equal in length to the equation or it can be repeated enough times to fit
1224-
if size == scal_size or (UnorderedSet.contains(cref, rep) and intMod(size, scal_size) == 0) then
1225-
for i in 1:size/scal_size loop
1226-
for scal in scalarized loop
1227-
for scal_idx in UnorderedMap.getSafe(scal, map3, sourceInfo()) loop
1228-
addMatrixEntry(m, modes, skip_idx + shift, scal_idx, mode);
1229-
shift := shift + 1;
1200+
try
1201+
// I. resolve the skips
1202+
d := UnorderedMap.getSafe(cref, dep, sourceInfo());
1203+
(start, _) := mapping.eqn_AtS[eqn_arr_idx];
1204+
(skip_idx, skip_ty) := resolveSkips(start, ty, d.skips);
1205+
1206+
// get equation and iterator sizes and frames
1207+
body_size := Type.sizeOf(skip_ty);
1208+
iter_size := Iterator.size(iter);
1209+
size := body_size * iter_size;
1210+
(names, ranges) := Iterator.getFrames(iter);
1211+
frames := List.zip(names, ranges);
1212+
1213+
// II. check for regular vs. reduced dimensions
1214+
regulars := Dependency.toBoolean(d);
1215+
if List.all(regulars, Util.id) then
1216+
// II.1 all regular - single dependency per row.
1217+
mode := Mode.create(eqn_name, {cref}, false);
1218+
scalarized := listReverse(ComponentRef.scalarizeAll(cref));
1219+
map3 := UnorderedMap.new<Val2>(ComponentRef.hash, ComponentRef.isEqual);
1220+
for scal in scalarized loop
1221+
UnorderedMap.add(scal, getCrefInFrameIndices(scal, frames, mapping, map), map3);
1222+
end for;
1223+
scal_size := listLength(List.flatten(UnorderedMap.valueList(map3)));
1224+
// either the scalarized list has to be equal in length to the equation or it can be repeated enough times to fit
1225+
if size == scal_size or (UnorderedSet.contains(cref, rep) and intMod(size, scal_size) == 0) then
1226+
for i in 1:size/scal_size loop
1227+
for scal in scalarized loop
1228+
for scal_idx in UnorderedMap.getSafe(scal, map3, sourceInfo()) loop
1229+
addMatrixEntry(m, modes, skip_idx + shift, scal_idx, mode);
1230+
shift := shift + 1;
1231+
end for;
12301232
end for;
12311233
end for;
1232-
end for;
1233-
else
1234-
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " (single dependency) failed because list of scalar variables("
1235-
+ intString(scal_size) + ") " + List.toString(scalarized, ComponentRef.toString)
1236-
+ ", does not fit the equation size " + intString(size) + ".\n"});
1237-
fail();
1238-
end if;
1234+
else
1235+
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " (single dependency) failed because list of scalar variables("
1236+
+ intString(scal_size) + ") " + List.toString(scalarized, ComponentRef.toString)
1237+
+ ", does not fit the equation size " + intString(size) + ".\n"});
1238+
fail();
1239+
end if;
12391240

1240-
elseif List.any(regulars, Util.id) then
1241-
// II.2 mixed regularity - find all necessary configurations and add them to a map with a proper key
1242-
// 1. get the cref subscripts and dimensions as well as the equation dimensions (they have to match in length)
1243-
subs := ComponentRef.subscriptsAllWithWholeFlat(cref);
1244-
dims := Type.arrayDims(ComponentRef.getSubscriptedType(cref));
1245-
eq_dims := Type.arrayDims(ty);
1246-
if listLength(subs) == listLength(dims) and listLength(subs) == listLength(regulars) and listLength(subs) == listLength(eq_dims) then
1247-
// 2. create a map that maps a configuration key to the corresponding scalar crefs
1248-
stripped := ComponentRef.stripSubscriptsAll(cref);
1249-
key := arrayCreate(listLength(subs), 0);
1250-
map1 := UnorderedMap.new<Val1>(keyHash, keyEqual);
1251-
resolveReductions(List.zip3(subs, dims, regulars), map1, key, stripped);
1252-
1253-
// 3. create a map that maps a configuration key to the final variable indices
1254-
map2 := UnorderedMap.new<Val2>(keyHash, keyEqual);
1255-
for k in UnorderedMap.keyList(map1) loop
1256-
scalarized := UnorderedMap.getSafe(k, map1, sourceInfo());
1257-
scal_lst := List.flatten(list(getCrefInFrameIndices(scal, frames, mapping, map) for scal in scalarized));
1258-
UnorderedMap.add(k, scal_lst, map2);
1259-
end for;
1241+
elseif List.any(regulars, Util.id) then
1242+
// II.2 mixed regularity - find all necessary configurations and add them to a map with a proper key
1243+
// 1. get the cref subscripts and dimensions as well as the equation dimensions (they have to match in length)
1244+
subs := ComponentRef.subscriptsAllWithWholeFlat(cref);
1245+
dims := Type.arrayDims(ComponentRef.getSubscriptedType(cref));
1246+
eq_dims := Type.arrayDims(ty);
1247+
if listLength(subs) == listLength(dims) and listLength(subs) == listLength(regulars) and listLength(subs) == listLength(eq_dims) then
1248+
// 2. create a map that maps a configuration key to the corresponding scalar crefs
1249+
stripped := ComponentRef.stripSubscriptsAll(cref);
1250+
key := arrayCreate(listLength(subs), 0);
1251+
map1 := UnorderedMap.new<Val1>(keyHash, keyEqual);
1252+
resolveReductions(List.zip3(subs, dims, regulars), map1, key, stripped);
1253+
1254+
// 3. create a map that maps a configuration key to the final variable indices
1255+
map2 := UnorderedMap.new<Val2>(keyHash, keyEqual);
1256+
for k in UnorderedMap.keyList(map1) loop
1257+
scalarized := UnorderedMap.getSafe(k, map1, sourceInfo());
1258+
scal_lst := List.flatten(list(getCrefInFrameIndices(scal, frames, mapping, map) for scal in scalarized));
1259+
UnorderedMap.add(k, scal_lst, map2);
1260+
end for;
12601261

1261-
// 4. iterate over all equation dimensions and use the map to get the correct dependencies
1262-
key := arrayCreate(listLength(subs), 0);
1263-
resolveEquationDimensions(List.zip(eq_dims, regulars), map2, key, m, modes, Mode.create(eqn_name, {cref}, false), Pointer.create(skip_idx));
1264-
else
1265-
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because subscripts, dimensions and dependencies were not of equal length.\n"
1266-
+ "variable subscripts(" + intString(listLength(subs)) + "): " + List.toString(subs, Subscript.toString) + "\n"
1267-
+ "variable dimensions(" + intString(listLength(dims)) + "): " + List.toString(dims, Dimension.toString) + "\n"
1268-
+ "equation dimensions(" + intString(listLength(eq_dims)) + "): " + List.toString(eq_dims, Dimension.toString) + "\n"
1269-
+ "variable dependencies(" + intString(listLength(regulars)) + "): " + List.toString(regulars, boolString) + "\n"});
1270-
fail();
1271-
end if;
1262+
// 4. iterate over all equation dimensions and use the map to get the correct dependencies
1263+
key := arrayCreate(listLength(subs), 0);
1264+
resolveEquationDimensions(List.zip(eq_dims, regulars), map2, key, m, modes, Mode.create(eqn_name, {cref}, false), Pointer.create(skip_idx));
1265+
else
1266+
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because subscripts, dimensions and dependencies were not of equal length.\n"
1267+
+ "variable subscripts(" + intString(listLength(subs)) + "): " + List.toString(subs, Subscript.toString) + "\n"
1268+
+ "variable dimensions(" + intString(listLength(dims)) + "): " + List.toString(dims, Dimension.toString) + "\n"
1269+
+ "equation dimensions(" + intString(listLength(eq_dims)) + "): " + List.toString(eq_dims, Dimension.toString) + "\n"
1270+
+ "variable dependencies(" + intString(listLength(regulars)) + "): " + List.toString(regulars, boolString) + "\n"});
1271+
fail();
1272+
end if;
12721273

1273-
else
1274-
// II.3 all reduced - full dependency per row. scalarize and add to all rows of the equation
1275-
repeated := UnorderedSet.contains(cref, rep);
1276-
scalarized := listReverse(ComponentRef.scalarizeAll(cref));
1277-
map3 := UnorderedMap.new<Val2>(ComponentRef.hash, ComponentRef.isEqual);
1278-
for scal in scalarized loop
1279-
UnorderedMap.add(scal, getCrefInFrameIndices(scal, frames, mapping, map), map3);
1280-
end for;
1274+
else
1275+
// II.3 all reduced - full dependency per row. scalarize and add to all rows of the equation
1276+
repeated := UnorderedSet.contains(cref, rep);
1277+
scalarized := listReverse(ComponentRef.scalarizeAll(cref));
1278+
map3 := UnorderedMap.new<Val2>(ComponentRef.hash, ComponentRef.isEqual);
1279+
for scal in scalarized loop
1280+
UnorderedMap.add(scal, getCrefInFrameIndices(scal, frames, mapping, map), map3);
1281+
end for;
12811282

1282-
// if its repeated, use the same cref always
1283-
if repeated then
1284-
mode := Mode.create(eqn_name, {cref}, false);
1285-
end if;
1283+
// if its repeated, use the same cref always
1284+
if repeated then
1285+
mode := Mode.create(eqn_name, {cref}, false);
1286+
end if;
12861287

1287-
for i in skip_idx:iter_size:skip_idx+size-iter_size loop
1288-
shift := 0;
1289-
for scal in scalarized loop
1290-
// if its not repeated use local cref
1291-
if not repeated then
1292-
mode := Mode.create(eqn_name, {scal}, true);
1293-
end if;
1294-
for scal_idx in UnorderedMap.getSafe(scal, map3, sourceInfo()) loop
1295-
if intMod(shift, iter_size) == 0 then shift := 0; end if;
1296-
addMatrixEntry(m, modes, i + shift, scal_idx, mode);
1297-
shift := shift + 1;
1288+
for i in skip_idx:iter_size:skip_idx+size-iter_size loop
1289+
shift := 0;
1290+
for scal in scalarized loop
1291+
// if its not repeated use local cref
1292+
if not repeated then
1293+
mode := Mode.create(eqn_name, {scal}, true);
1294+
end if;
1295+
for scal_idx in UnorderedMap.getSafe(scal, map3, sourceInfo()) loop
1296+
if intMod(shift, iter_size) == 0 then shift := 0; end if;
1297+
addMatrixEntry(m, modes, i + shift, scal_idx, mode);
1298+
shift := shift + 1;
1299+
end for;
12981300
end for;
12991301
end for;
1300-
end for;
1301-
end if;
1302+
end if;
1303+
else
1304+
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + ComponentRef.toString(cref) + "."});
1305+
fail();
1306+
end try;
13021307
end resolveDependency;
13031308

13041309
function resolveEquationDimensions

0 commit comments

Comments
 (0)