Skip to content

Commit

Permalink
[NB] update array adjacency matrix entries (#9820)
Browse files Browse the repository at this point in the history
  • Loading branch information
kabdelhak committed Dec 1, 2022
1 parent 91e0355 commit a187832
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 4 deletions.
54 changes: 51 additions & 3 deletions OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBAdjacency.mo
Expand Up @@ -43,6 +43,7 @@ protected
import Dimension = NFDimension;
import Expression = NFExpression;
import FunctionTree = NFFlatten.FunctionTree;
import Subscript = NFSubscript;
import Type = NFType;
import Variable = NFVariable;

Expand Down Expand Up @@ -173,15 +174,62 @@ public
function getVarScalIndices
input Integer arr_idx;
input Mapping mapping;
input list<Subscript> subs;
input list<Dimension> dims;
input Boolean reverse = false;
output list<Integer> scal_indices;
protected
Integer start, length;
function subscriptedIndices
input Integer start;
input Integer length;
input list<Integer> slice;
output list<Integer> scal_indices;
algorithm
scal_indices := List.intRange2(start, start + length - 1);
if not listEmpty(slice) then
scal_indices := List.keepPositions(scal_indices, slice);
end if;
end subscriptedIndices;
algorithm
(start, length) := mapping.var_AtS[arr_idx];
scal_indices := if reverse then
List.intRange2(start + length - 1, start) else
List.intRange2(start, start + length - 1);

scal_indices := match subs
local
Subscript sub;
list<list<Subscript>> subs_lst;
list<Integer> slice = {}, dim_sizes, values;
list<tuple<Integer, Integer>> ranges;

// no subscripts -> create full index list
case {} then subscriptedIndices(start, length, {});

// all subscripts are whole -> create full index list
case _ guard(List.all(subs, Subscript.isWhole)) then subscriptedIndices(start, length, {});

// only one subscript -> apply simple rule
case {sub} algorithm
slice := Subscript.toIndexList(sub, length);
then subscriptedIndices(start, length, slice);

// multiple subscripts -> apply location to index mapping rules
case _ algorithm
subs_lst := Subscript.scalarizeList(subs, dims);
subs_lst := List.combination(subs_lst);
dim_sizes := list(Dimension.size(dim) for dim in dims);
for sub_lst in listReverse(subs_lst) loop
values := list(Subscript.toInteger(s) for s in sub_lst);
ranges := List.zip(dim_sizes, values);
slice := Slice.locationToIndex(ranges, start) :: slice;
end for;
then slice;

else fail();
end match;

if reverse then
scal_indices := listReverse(scal_indices);
end if;
end getVarScalIndices;

protected
Expand Down
8 changes: 7 additions & 1 deletion OMCompiler/Compiler/NBackEnd/Util/NBSlice.mo
Expand Up @@ -39,6 +39,7 @@ protected

// NF imports
import ComponentRef = NFComponentRef;
import Dimension = NFDimension;
import Expression = NFExpression;
import Operator = NFOperator;
import SimplifyExp = NFSimplifyExp;
Expand Down Expand Up @@ -291,6 +292,8 @@ public
list<Integer> scal_lst;
Integer idx;
array<Integer> mode_to_var_row;
list<Subscript> subs;
list<Dimension> dims;
algorithm
(eqn_start, eqn_size) := mapping.eqn_AtS[eqn_arr_idx];
indices := arrayCreate(eqn_size, {});
Expand All @@ -302,8 +305,11 @@ public
for cref in dependencies loop
stripped := ComponentRef.stripSubscriptsAll(cref);
var_arr_idx := UnorderedMap.getSafe(stripped, map, sourceInfo());

// build range in reverse, it will be flipped anyway
scal_lst := Mapping.getVarScalIndices(var_arr_idx, mapping, true);
subs := ComponentRef.subscriptsAllWithWholeFlat(cref);
dims := Type.arrayDims(ComponentRef.getSubscriptedType(stripped));
scal_lst := Mapping.getVarScalIndices(var_arr_idx, mapping, subs, dims, true);

if intMod(eqn_size, listLength(scal_lst)) <> 0 then
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName()
Expand Down
38 changes: 38 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFSubscript.mo
Expand Up @@ -131,6 +131,44 @@ public
end match;
end toInteger;

function toIndexList
input Subscript subscript;
input Integer length;
input Boolean baseZero = true;
output list<Integer> indices;
protected
Integer shift = if baseZero then 1 else 0;
algorithm
indices := match subscript
local
array<Expression> elems;
Integer start, step, stop;

case INDEX() then {toInteger(subscript)-shift};

case WHOLE() then List.intRange2(1-shift,length-shift);

case SLICE(slice = Expression.ARRAY(elements = elems))
then list(Expression.toInteger(e) for e in elems);

case SLICE(slice = Expression.RANGE(
start = Expression.INTEGER(start),
step = SOME(Expression.INTEGER(step)),
stop = Expression.INTEGER(stop)))
then List.intRange3(start-shift, step, stop-shift);

case SLICE(slice = Expression.RANGE(
start = Expression.INTEGER(start),
step = NONE(),
stop = Expression.INTEGER(stop)))
then List.intRange2(start-shift, stop-shift);

else algorithm
Error.assertion(false, getInstanceName() + " got an incorrect subscript type " + toString(subscript) + ".", sourceInfo());
then fail();
end match;
end toIndexList;

protected function isValidIndexType
input Type ty;
output Boolean b = Type.isInteger(ty) or Type.isBoolean(ty) or Type.isEnumeration(ty);
Expand Down

0 comments on commit a187832

Please sign in to comment.