Skip to content

Commit

Permalink
[NB] fix array equations in for-loops (#11784)
Browse files Browse the repository at this point in the history
* [NB] fix array equations in for-loops

 - correctly flatten the variables without creating a mode for each one

* [NB] only skip non simple array assignments for new backend
  • Loading branch information
kabdelhak committed Jan 4, 2024
1 parent 29aab86 commit 54cc78f
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 13 deletions.
5 changes: 2 additions & 3 deletions OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBAdjacency.mo
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,6 @@ public
input list<ComponentRef> unique_dependencies;
protected
// get clean pointers -> type checking fails otherwise
list<ComponentRef> scalarized_dependencies = List.flatten(list(ComponentRef.scalarizeAll(dep) for dep in unique_dependencies));
array<array<Integer>> mode_to_var = modes.mode_to_var;
array<array<ComponentRef>> mode_to_cref = modes.mode_to_cref;
algorithm
Expand All @@ -352,7 +351,7 @@ public
end for;

// create array mode to cref mapping
arrayUpdate(mode_to_cref, eqn_arr_idx, arrayAppend(listArray(scalarized_dependencies), mode_to_cref[eqn_arr_idx]));
arrayUpdate(mode_to_cref, eqn_arr_idx, arrayAppend(listArray(unique_dependencies), mode_to_cref[eqn_arr_idx]));
end update;

function clean
Expand Down Expand Up @@ -712,7 +711,7 @@ public
try
updateRow(eqn_ptr, diffArgs_ptr, st, vars.map, m, mapping, modes, eqn_idx_arr, funcTree);
else
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + Equation.pointerToString(eqn_ptr)});
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for:\n" + Equation.pointerToString(eqn_ptr)});
fail();
end try;
eqn_idx_arr := eqn_idx_arr + 1;
Expand Down
51 changes: 41 additions & 10 deletions OMCompiler/Compiler/NBackEnd/Util/NBSlice.mo
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,8 @@ public
list<Integer> scal_lst;
Integer idx;
array<Integer> mode_to_var_row;
list<tuple<ComponentRef, list<Integer>>> scal_tpl_lst;
Integer num_flat_indices;
algorithm
// get iterator frames
(names, ranges) := Iterator.getFrames(iter);
Expand All @@ -426,28 +428,57 @@ public
for i in 1:eqn_size loop
mode_to_var[i] := arrayCreate(listLength(dependencies),-1);
end for;

// create rows
for cref in dependencies loop
scal_lst := getCrefInFrameIndices(cref, frames, mapping, map);
scal_tpl_lst := {};
// 1. scalarize cref and collect all indices per scalar cref
for scal_cref in ComponentRef.scalarizeAll(cref) loop
scal_lst := getCrefInFrameIndices(scal_cref, frames, mapping, map);
scal_tpl_lst := (scal_cref, scal_lst) :: scal_tpl_lst;
end for;

if listLength(scal_lst) <> eqn_size then
// 2. the total number of indices for this dependency has to be equal to
// the length of the equation. E.g following equation of size is 5*3=15
// for i in 1:5 loop
// x[i].y[3:5] = ...
// end for;
// and the cref x[i].y[3:5] is scalarized to {x[i].y[3], x[i].y[4], x[i].y[5]}
// this list evaluated at each iterator position will each give 5 integers
// resulting in the desired 5*3 integers
num_flat_indices := sum(listLength(Util.tuple22(tpl)) for tpl in scal_tpl_lst);
if num_flat_indices <> eqn_size then
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName()
+ " failed because number of flattened indices " + intString(listLength(scal_lst))
+ " failed because number of flattened indices " + intString(num_flat_indices)
+ " differ from equation size " + intString(eqn_size) + "."});
fail();
end if;

// 3. create the causalization mode and adjacency matrix arrays
// Note that there is only one mode per array variable!
// for the previous example x[i].y[3:5] all 15 scalarized
// indices will get the same mode, but spread on their respective 15 scalar equations
idx := 1;
for var_scal_idx in listReverse(scal_lst) loop
mode_to_var_row := mode_to_var[idx];
mode_to_var_row[mode] := var_scal_idx;
arrayUpdate(mode_to_var_row, mode, var_scal_idx);
indices[idx] := var_scal_idx :: indices[idx];
idx := idx + 1;
for tpl in scal_tpl_lst loop
(_, scal_lst) := tpl;
for var_scal_idx in listReverse(scal_lst) loop
// get the clean pointer to the scalar row to avoid double indexing (meta modelica jank)
mode_to_var_row := mode_to_var[idx];
// set the dependency mode for this scalar equation to the scalar variable
mode_to_var_row[mode] := var_scal_idx;
arrayUpdate(mode_to_var_row, mode, var_scal_idx);
// this is the adjacency matrix row. each dependency cref
// will add exactly one integer to each row belonging to this for-equation
indices[idx] := var_scal_idx :: indices[idx];
idx := idx + 1;
end for;
end for;

// increase mode index
mode := mode + 1;
end for;

// sort
// sort (kabdelhak: is this needed? try to FixMe)
for i in 1:arrayLength(indices) loop
indices[i] := List.sort(UnorderedSet.unique_list(indices[i], Util.id, intEq), intLt);
end for;
Expand Down
6 changes: 6 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFFlatten.mo
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,13 @@ algorithm
DAE.ElementSource src;

// convert simple equality of crefs to array equality
// kabdelhak: only do it if all subscripts are simple enough
// will lead to complicated code if not index or whole dim
// and we are better of just using for loops for these
case Equation.EQUALITY(lhs = lhs as Expression.CREF(), rhs = rhs as Expression.CREF())
guard(not Flags.getConfigBool(Flags.NEW_BACKEND)
or (List.all(ComponentRef.subscriptsAllWithWholeFlat(lhs.cref), Subscript.isSimple)
and List.all(ComponentRef.subscriptsAllWithWholeFlat(rhs.cref), Subscript.isSimple)))
algorithm
ty := Type.liftArrayLeftList(eq.ty, dimensions);
lhs := Expression.CREF(ty, lhs.cref);
Expand Down
7 changes: 7 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFSubscript.mo
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,13 @@ public
end match;
end isWhole;

function isSimple
"used for determining if its simple enough to use for an array equation
in the case of non scalarization (new backend)"
input Subscript sub;
output Boolean isSimple = isIndex(sub) or isWhole(sub);
end isSimple;

function isSliced
input Subscript sub;
output Boolean sliced;
Expand Down

0 comments on commit 54cc78f

Please sign in to comment.